Commit 748eceae authored by Marianne Linhares Monteiro's avatar Marianne Linhares Monteiro Committed by GitHub
Browse files

Merge branch 'master' into cifar10_experiment

parents 40e906d2 ed65b632
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the various loss functions in use by the PIXELDA model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
slim = tf.contrib.slim
def add_domain_classifier_losses(end_points, hparams):
"""Adds losses related to the domain-classifier.
Args:
end_points: A map of network end point names to `Tensors`.
hparams: The hyperparameters struct.
Returns:
loss: A `Tensor` representing the total task-classifier loss.
"""
if hparams.domain_loss_weight == 0:
tf.logging.info(
'Domain classifier loss weight is 0, so not creating losses.')
return 0
# The domain prediction loss is minimized with respect to the domain
# classifier features only. Its aim is to predict the domain of the images.
# Note: 1 = 'real image' label, 0 = 'fake image' label
transferred_domain_loss = tf.losses.sigmoid_cross_entropy(
multi_class_labels=tf.zeros_like(end_points['transferred_domain_logits']),
logits=end_points['transferred_domain_logits'])
tf.summary.scalar('Domain_loss_transferred', transferred_domain_loss)
target_domain_loss = tf.losses.sigmoid_cross_entropy(
multi_class_labels=tf.ones_like(end_points['target_domain_logits']),
logits=end_points['target_domain_logits'])
tf.summary.scalar('Domain_loss_target', target_domain_loss)
# Compute the total domain loss:
total_domain_loss = transferred_domain_loss + target_domain_loss
total_domain_loss *= hparams.domain_loss_weight
tf.summary.scalar('Domain_loss_total', total_domain_loss)
return total_domain_loss
def log_quaternion_loss_batch(predictions, labels, params):
"""A helper function to compute the error between quaternions.
Args:
predictions: A Tensor of size [batch_size, 4].
labels: A Tensor of size [batch_size, 4].
params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.
Returns:
A Tensor of size [batch_size], denoting the error between the quaternions.
"""
use_logging = params['use_logging']
assertions = []
if use_logging:
assertions.append(
tf.Assert(
tf.reduce_all(
tf.less(
tf.abs(tf.reduce_sum(tf.square(predictions), [1]) - 1),
1e-4)),
['The l2 norm of each prediction quaternion vector should be 1.']))
assertions.append(
tf.Assert(
tf.reduce_all(
tf.less(
tf.abs(tf.reduce_sum(tf.square(labels), [1]) - 1), 1e-4)),
['The l2 norm of each label quaternion vector should be 1.']))
with tf.control_dependencies(assertions):
product = tf.multiply(predictions, labels)
internal_dot_products = tf.reduce_sum(product, [1])
if use_logging:
internal_dot_products = tf.Print(internal_dot_products, [
internal_dot_products,
tf.shape(internal_dot_products)
], 'internal_dot_products:')
logcost = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
return logcost
def log_quaternion_loss(predictions, labels, params):
"""A helper function to compute the mean error between batches of quaternions.
The caller is expected to add the loss to the graph.
Args:
predictions: A Tensor of size [batch_size, 4].
labels: A Tensor of size [batch_size, 4].
params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.
Returns:
A Tensor of size 1, denoting the mean error between batches of quaternions.
"""
use_logging = params['use_logging']
logcost = log_quaternion_loss_batch(predictions, labels, params)
logcost = tf.reduce_sum(logcost, [0])
batch_size = params['batch_size']
logcost = tf.multiply(logcost, 1.0 / batch_size, name='log_quaternion_loss')
if use_logging:
logcost = tf.Print(
logcost, [logcost], '[logcost]', name='log_quaternion_loss_print')
return logcost
def _quaternion_loss(labels, predictions, weight, batch_size, domain,
add_summaries):
"""Creates a Quaternion Loss.
Args:
labels: The true quaternions.
predictions: The predicted quaternions.
weight: A scalar weight.
batch_size: The size of the batches.
domain: The name of the domain from which the labels were taken.
add_summaries: Whether or not to add summaries for the losses.
Returns:
A `Tensor` representing the loss.
"""
assert domain in ['Source', 'Transferred']
params = {'use_logging': False, 'batch_size': batch_size}
loss = weight * log_quaternion_loss(labels, predictions, params)
if add_summaries:
assert_op = tf.Assert(tf.is_finite(loss), [loss])
with tf.control_dependencies([assert_op]):
tf.summary.histogram(
'Log_Quaternion_Loss_%s' % domain, loss, collections='losses')
tf.summary.scalar(
'Task_Quaternion_Loss_%s' % domain, loss, collections='losses')
return loss
def _add_task_specific_losses(end_points, source_labels, num_classes, hparams,
add_summaries=False):
"""Adds losses related to the task-classifier.
Args:
end_points: A map of network end point names to `Tensors`.
source_labels: A dictionary of output labels to `Tensors`.
num_classes: The number of classes used by the classifier.
hparams: The hyperparameters struct.
add_summaries: Whether or not to add the summaries.
Returns:
loss: A `Tensor` representing the total task-classifier loss.
"""
# TODO(ddohan): Make sure the l2 regularization is added to the loss
one_hot_labels = slim.one_hot_encoding(source_labels['class'], num_classes)
total_loss = 0
if 'source_task_logits' in end_points:
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels,
logits=end_points['source_task_logits'],
weights=hparams.source_task_loss_weight)
if add_summaries:
tf.summary.scalar('Task_Classifier_Loss_Source', loss)
total_loss += loss
if 'transferred_task_logits' in end_points:
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels,
logits=end_points['transferred_task_logits'],
weights=hparams.transferred_task_loss_weight)
if add_summaries:
tf.summary.scalar('Task_Classifier_Loss_Transferred', loss)
total_loss += loss
#########################
# Pose specific losses. #
#########################
if 'quaternion' in source_labels:
total_loss += _quaternion_loss(
source_labels['quaternion'],
end_points['source_quaternion'],
hparams.source_pose_weight,
hparams.batch_size,
'Source',
add_summaries)
total_loss += _quaternion_loss(
source_labels['quaternion'],
end_points['transferred_quaternion'],
hparams.transferred_pose_weight,
hparams.batch_size,
'Transferred',
add_summaries)
if add_summaries:
tf.summary.scalar('Task_Loss_Total', total_loss)
return total_loss
def _transferred_similarity_loss(reconstructions,
source_images,
weight=1.0,
method='mse',
max_diff=0.4,
name='similarity'):
"""Computes a loss encouraging similarity between source and transferred.
Args:
reconstructions: A `Tensor` of shape [batch_size, height, width, channels]
source_images: A `Tensor` of shape [batch_size, height, width, channels].
weight: Multiple similarity loss by this weight before returning
method: One of:
mpse = Mean Pairwise Squared Error
mse = Mean Squared Error
hinged_mse = Computes the mean squared error using squared differences
greater than hparams.transferred_similarity_max_diff
hinged_mae = Computes the mean absolute error using absolute
differences greater than hparams.transferred_similarity_max_diff.
max_diff: Maximum unpenalized difference for hinged losses
name: Identifying name to use for creating summaries
Returns:
A `Tensor` representing the transferred similarity loss.
Raises:
ValueError: if `method` is not recognized.
"""
if weight == 0:
return 0
source_channels = source_images.shape.as_list()[-1]
reconstruction_channels = reconstructions.shape.as_list()[-1]
# Convert grayscale source to RGB if target is RGB
if source_channels == 1 and reconstruction_channels != 1:
source_images = tf.tile(source_images, [1, 1, 1, reconstruction_channels])
if reconstruction_channels == 1 and source_channels != 1:
reconstructions = tf.tile(reconstructions, [1, 1, 1, source_channels])
if method == 'mpse':
reconstruction_similarity_loss_fn = (
tf.contrib.losses.mean_pairwise_squared_error)
elif method == 'masked_mpse':
def masked_mpse(predictions, labels, weight):
"""Masked mpse assuming we have a depth to create a mask from."""
assert labels.shape.as_list()[-1] == 4
mask = tf.to_float(tf.less(labels[:, :, :, 3:4], 0.99))
mask = tf.tile(mask, [1, 1, 1, 4])
predictions *= mask
labels *= mask
tf.image_summary('masked_pred', predictions)
tf.image_summary('masked_label', labels)
return tf.contrib.losses.mean_pairwise_squared_error(
predictions, labels, weight)
reconstruction_similarity_loss_fn = masked_mpse
elif method == 'mse':
reconstruction_similarity_loss_fn = tf.contrib.losses.mean_squared_error
elif method == 'hinged_mse':
def hinged_mse(predictions, labels, weight):
diffs = tf.square(predictions - labels)
diffs = tf.maximum(0.0, diffs - max_diff)
return tf.reduce_mean(diffs) * weight
reconstruction_similarity_loss_fn = hinged_mse
elif method == 'hinged_mae':
def hinged_mae(predictions, labels, weight):
diffs = tf.abs(predictions - labels)
diffs = tf.maximum(0.0, diffs - max_diff)
return tf.reduce_mean(diffs) * weight
reconstruction_similarity_loss_fn = hinged_mae
else:
raise ValueError('Unknown reconstruction loss %s' % method)
reconstruction_similarity_loss = reconstruction_similarity_loss_fn(
reconstructions, source_images, weight)
name = '%s_Similarity_(%s)' % (name, method)
tf.summary.scalar(name, reconstruction_similarity_loss)
return reconstruction_similarity_loss
def g_step_loss(source_images, source_labels, end_points, hparams, num_classes):
"""Configures the loss function which runs during the g-step.
Args:
source_images: A `Tensor` of shape [batch_size, height, width, channels].
source_labels: A dictionary of `Tensors` of shape [batch_size]. Valid keys
are 'class' and 'quaternion'.
end_points: A map of the network end points.
hparams: The hyperparameters struct.
num_classes: Number of classes for classifier loss
Returns:
A `Tensor` representing a loss function.
Raises:
ValueError: if hparams.transferred_similarity_loss_weight is non-zero but
hparams.transferred_similarity_loss is invalid.
"""
generator_loss = 0
################################################################
# Adds a loss which encourages the discriminator probabilities #
# to be high (near one).
################################################################
# As per the GAN paper, maximize the log probs, instead of minimizing
# log(1-probs). Since we're minimizing, we'll minimize -log(probs) which is
# the same thing.
style_transfer_loss = tf.losses.sigmoid_cross_entropy(
logits=end_points['transferred_domain_logits'],
multi_class_labels=tf.ones_like(end_points['transferred_domain_logits']),
weights=hparams.style_transfer_loss_weight)
tf.summary.scalar('Style_transfer_loss', style_transfer_loss)
generator_loss += style_transfer_loss
# Optimizes the style transfer network to produce transferred images similar
# to the source images.
generator_loss += _transferred_similarity_loss(
end_points['transferred_images'],
source_images,
weight=hparams.transferred_similarity_loss_weight,
method=hparams.transferred_similarity_loss,
name='transferred_similarity')
# Optimizes the style transfer network to maximize classification accuracy.
if source_labels is not None and hparams.task_tower_in_g_step:
generator_loss += _add_task_specific_losses(
end_points, source_labels, num_classes,
hparams) * hparams.task_loss_in_g_weight
return generator_loss
def d_step_loss(end_points, source_labels, num_classes, hparams):
"""Configures the losses during the D-Step.
Note that during the D-step, the model optimizes both the domain (binary)
classifier and the task classifier.
Args:
end_points: A map of the network end points.
source_labels: A dictionary of output labels to `Tensors`.
num_classes: The number of classes used by the classifier.
hparams: The hyperparameters struct.
Returns:
A `Tensor` representing the value of the D-step loss.
"""
domain_classifier_loss = add_domain_classifier_losses(end_points, hparams)
task_classifier_loss = 0
if source_labels is not None:
task_classifier_loss = _add_task_specific_losses(
end_points, source_labels, num_classes, hparams, add_summaries=True)
return domain_classifier_loss + task_classifier_loss
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the Domain Adaptation via Style Transfer (PixelDA) model components.
A number of details in the implementation make reference to one of the following
works:
- "Unsupervised Representation Learning with Deep Convolutional
Generative Adversarial Networks""
https://arxiv.org/abs/1511.06434
This paper makes several architecture recommendations:
1. Use strided convs in discriminator, fractional-strided convs in generator
2. batchnorm everywhere
3. remove fully connected layers for deep models
4. ReLu for all layers in generator, except tanh on output
5. LeakyReLu for everything in discriminator
"""
import functools
import math
# Dependency imports
import numpy as np
import tensorflow as tf
slim = tf.contrib.slim
from domain_adaptation.pixel_domain_adaptation import pixelda_task_towers
def create_model(hparams,
target_images,
source_images=None,
source_labels=None,
is_training=False,
noise=None,
num_classes=None):
"""Create a GAN model.
Arguments:
hparams: HParam object specifying model params
target_images: A `Tensor` of size [batch_size, height, width, channels]. It
is assumed that the images are [-1, 1] normalized.
source_images: A `Tensor` of size [batch_size, height, width, channels]. It
is assumed that the images are [-1, 1] normalized.
source_labels: A `Tensor` of size [batch_size] of categorical labels between
[0, num_classes]
is_training: whether model is currently training
noise: If None, model generates its own noise. Otherwise use provided.
num_classes: Number of classes for classification
Returns:
end_points dict with model outputs
Raises:
ValueError: unknown hparams.arch setting
"""
if num_classes is None and hparams.arch in ['resnet', 'simple']:
raise ValueError('Num classes must be provided to create task classifier')
if target_images.dtype != tf.float32:
raise ValueError('target_images must be tf.float32 and [-1, 1] normalized.')
if source_images is not None and source_images.dtype != tf.float32:
raise ValueError('source_images must be tf.float32 and [-1, 1] normalized.')
###########################
# Create latent variables #
###########################
latent_vars = dict()
if hparams.noise_channel:
noise_shape = [hparams.batch_size, hparams.noise_dims]
if noise is not None:
assert noise.shape.as_list() == noise_shape
tf.logging.info('Using provided noise')
else:
tf.logging.info('Using random noise')
noise = tf.random_uniform(
shape=noise_shape,
minval=-1,
maxval=1,
dtype=tf.float32,
name='random_noise')
latent_vars['noise'] = noise
####################
# Create generator #
####################
with slim.arg_scope(
[slim.conv2d, slim.conv2d_transpose, slim.fully_connected],
normalizer_params=batch_norm_params(is_training,
hparams.batch_norm_decay),
weights_initializer=tf.random_normal_initializer(
stddev=hparams.normal_init_std),
weights_regularizer=tf.contrib.layers.l2_regularizer(
hparams.weight_decay)):
with slim.arg_scope([slim.conv2d], padding='SAME'):
if hparams.arch == 'dcgan':
end_points = dcgan(
target_images, latent_vars, hparams, scope='generator')
elif hparams.arch == 'resnet':
end_points = resnet_generator(
source_images,
target_images.shape.as_list()[1:4],
hparams=hparams,
latent_vars=latent_vars)
elif hparams.arch == 'residual_interpretation':
end_points = residual_interpretation_generator(
source_images, is_training=is_training, hparams=hparams)
elif hparams.arch == 'simple':
end_points = simple_generator(
source_images,
target_images,
is_training=is_training,
hparams=hparams,
latent_vars=latent_vars)
elif hparams.arch == 'identity':
# Pass through unmodified, besides changing # channels
# Used to calculate baseline numbers
# Also set `generator_steps=0` for baseline
if hparams.generator_steps:
raise ValueError('Must set generator_steps=0 for identity arch. Is %s'
% hparams.generator_steps)
transferred_images = source_images
source_channels = source_images.shape.as_list()[-1]
target_channels = target_images.shape.as_list()[-1]
if source_channels == 1 and target_channels == 3:
transferred_images = tf.tile(source_images, [1, 1, 1, 3])
if source_channels == 3 and target_channels == 1:
transferred_images = tf.image.rgb_to_grayscale(source_images)
end_points = {'transferred_images': transferred_images}
else:
raise ValueError('Unknown architecture: %s' % hparams.arch)
#####################
# Domain Classifier #
#####################
if hparams.arch in [
'dcgan', 'resnet', 'residual_interpretation', 'simple', 'identity',
]:
# Add a discriminator for these architectures
end_points['transferred_domain_logits'] = predict_domain(
end_points['transferred_images'],
hparams,
is_training=is_training,
reuse=False)
end_points['target_domain_logits'] = predict_domain(
target_images,
hparams,
is_training=is_training,
reuse=True)
###################
# Task Classifier #
###################
if hparams.task_tower != 'none' and hparams.arch in [
'resnet', 'residual_interpretation', 'simple', 'identity',
]:
with tf.variable_scope('discriminator'):
with tf.variable_scope('task_tower'):
end_points['source_task_logits'], end_points[
'source_quaternion'] = pixelda_task_towers.add_task_specific_model(
source_images,
hparams,
num_classes=num_classes,
is_training=is_training,
reuse_private=False,
private_scope='source_task_classifier',
reuse_shared=False)
end_points['transferred_task_logits'], end_points[
'transferred_quaternion'] = (
pixelda_task_towers.add_task_specific_model(
end_points['transferred_images'],
hparams,
num_classes=num_classes,
is_training=is_training,
reuse_private=False,
private_scope='transferred_task_classifier',
reuse_shared=True))
end_points['target_task_logits'], end_points[
'target_quaternion'] = pixelda_task_towers.add_task_specific_model(
target_images,
hparams,
num_classes=num_classes,
is_training=is_training,
reuse_private=True,
private_scope='transferred_task_classifier',
reuse_shared=True)
# Remove any endpoints with None values
return dict((k, v) for k, v in end_points.iteritems() if v is not None)
def batch_norm_params(is_training, batch_norm_decay):
return {
'is_training': is_training,
# Decay for the moving averages.
'decay': batch_norm_decay,
# epsilon to prevent 0s in variance.
'epsilon': 0.001,
}
def lrelu(x, leakiness=0.2):
"""Relu, with optional leaky support."""
return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu')
def upsample(net, num_filters, scale=2, method='resize_conv', scope=None):
"""Performs spatial upsampling of the given features.
Args:
net: A `Tensor` of shape [batch_size, height, width, filters].
num_filters: The number of output filters.
scale: The scale of the upsampling. Must be a positive integer greater or
equal to two.
method: The method by which the features are upsampled. Valid options
include 'resize_conv' and 'conv2d_transpose'.
scope: An optional variable scope.
Returns:
A new set of features of shape
[batch_size, height*scale, width*scale, num_filters].
Raises:
ValueError: if `method` is not valid or
"""
if scale < 2:
raise ValueError('scale must be greater or equal to two.')
with tf.variable_scope(scope, 'upsample', [net]):
if method == 'resize_conv':
net = tf.image.resize_nearest_neighbor(
net, [net.shape.as_list()[1] * scale,
net.shape.as_list()[2] * scale],
align_corners=True,
name='resize')
return slim.conv2d(net, num_filters, stride=1, scope='conv')
elif method == 'conv2d_transpose':
return slim.conv2d_transpose(net, num_filters, scope='deconv')
else:
raise ValueError('Upsample method [%s] was not recognized.' % method)
def project_latent_vars(hparams, proj_shape, latent_vars, combine_method='sum'):
"""Generate noise and project to input volume size.
Args:
hparams: The hyperparameter HParams struct.
proj_shape: Shape to project noise (not including batch size).
latent_vars: dictionary of `'key': Tensor of shape [batch_size, N]`
combine_method: How to combine the projected values.
sum = project to volume then sum
concat = concatenate along last dimension (i.e. channel)
Returns:
If combine_method=sum, a `Tensor` of size `hparams.projection_shape`
If combine_method=concat and there are N latent vars, a `Tensor` of size
`hparams.projection_shape`, with the last channel multiplied by N
Raises:
ValueError: combine_method is not one of sum/concat
"""
values = []
for var in latent_vars:
with tf.variable_scope(var):
# Project & reshape noise to a HxWxC input
projected = slim.fully_connected(
latent_vars[var],
np.prod(proj_shape),
activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm)
values.append(tf.reshape(projected, [hparams.batch_size] + proj_shape))
if combine_method == 'sum':
result = values[0]
for value in values[1:]:
result += value
elif combine_method == 'concat':
# Concatenate along last axis
result = tf.concat(values, len(proj_shape))
else:
raise ValueError('Unknown combine_method %s' % combine_method)
tf.logging.info('Latent variables projected to size %s volume', result.shape)
return result
def resnet_block(net, hparams):
"""Create a resnet block."""
net_in = net
net = slim.conv2d(
net,
hparams.resnet_filters,
stride=1,
normalizer_fn=slim.batch_norm,
activation_fn=tf.nn.relu)
net = slim.conv2d(
net,
hparams.resnet_filters,
stride=1,
normalizer_fn=slim.batch_norm,
activation_fn=None)
if hparams.resnet_residuals:
net += net_in
return net
def resnet_stack(images, output_shape, hparams, scope=None):
"""Create a resnet style transfer block.
Args:
images: [batch-size, height, width, channels] image tensor to feed as input
output_shape: output image shape in form [height, width, channels]
hparams: hparams objects
scope: Variable scope
Returns:
Images after processing with resnet blocks.
"""
end_points = {}
if hparams.noise_channel:
# separate the noise for visualization
end_points['noise'] = images[:, :, :, -1]
assert images.shape.as_list()[1:3] == output_shape[0:2]
with tf.variable_scope(scope, 'resnet_style_transfer', [images]):
with slim.arg_scope(
[slim.conv2d],
normalizer_fn=slim.batch_norm,
kernel_size=[hparams.generator_kernel_size] * 2,
stride=1):
net = slim.conv2d(
images,
hparams.resnet_filters,
normalizer_fn=None,
activation_fn=tf.nn.relu)
for block in range(hparams.resnet_blocks):
net = resnet_block(net, hparams)
end_points['resnet_block_{}'.format(block)] = net
net = slim.conv2d(
net,
output_shape[-1],
kernel_size=[1, 1],
normalizer_fn=None,
activation_fn=tf.nn.tanh,
scope='conv_out')
end_points['transferred_images'] = net
return net, end_points
def predict_domain(images,
hparams,
is_training=False,
reuse=False,
scope='discriminator'):
"""Creates a discriminator for a GAN.
Args:
images: A `Tensor` of size [batch_size, height, width, channels]. It is
assumed that the images are centered between -1 and 1.
hparams: hparam object with params for discriminator
is_training: Specifies whether or not we're training or testing.
reuse: Whether to reuse variable scope
scope: An optional variable_scope.
Returns:
[batch size, 1] - logit output of discriminator.
"""
with tf.variable_scope(scope, 'discriminator', [images], reuse=reuse):
lrelu_partial = functools.partial(lrelu, leakiness=hparams.lrelu_leakiness)
with slim.arg_scope(
[slim.conv2d],
kernel_size=[hparams.discriminator_kernel_size] * 2,
activation_fn=lrelu_partial,
stride=2,
normalizer_fn=slim.batch_norm):
def add_noise(hidden, scope_num=None):
if scope_num:
hidden = slim.dropout(
hidden,
hparams.discriminator_dropout_keep_prob,
is_training=is_training,
scope='dropout_%s' % scope_num)
if hparams.discriminator_noise_stddev == 0:
return hidden
return hidden + tf.random_normal(
hidden.shape.as_list(),
mean=0.0,
stddev=hparams.discriminator_noise_stddev)
# As per the recommendation of the DCGAN paper, we don't use batch norm
# on the discriminator input (https://arxiv.org/pdf/1511.06434v2.pdf).
if hparams.discriminator_image_noise:
images = add_noise(images)
net = slim.conv2d(
images,
hparams.num_discriminator_filters,
normalizer_fn=None,
stride=hparams.discriminator_first_stride,
scope='conv1_stride%s' % hparams.discriminator_first_stride)
net = add_noise(net, 1)
block_id = 2
# Repeatedly stack
# discriminator_conv_block_size-1 conv layers with stride 1
# followed by a stride 2 layer
# Add (optional) noise at every point
while net.shape.as_list()[1] > hparams.projection_shape_size:
num_filters = int(hparams.num_discriminator_filters *
(hparams.discriminator_filter_factor**(block_id - 1)))
for conv_id in range(1, hparams.discriminator_conv_block_size):
net = slim.conv2d(
net,
num_filters,
stride=1,
scope='conv_%s_%s' % (block_id, conv_id))
if hparams.discriminator_do_pooling:
net = slim.conv2d(
net, num_filters, scope='conv_%s_prepool' % block_id)
net = slim.avg_pool2d(
net, kernel_size=[2, 2], stride=2, scope='pool_%s' % block_id)
else:
net = slim.conv2d(
net, num_filters, scope='conv_%s_stride2' % block_id)
net = add_noise(net, block_id)
block_id += 1
net = slim.flatten(net)
net = slim.fully_connected(
net,
1,
# Models with BN here generally produce noise
normalizer_fn=None,
activation_fn=None,
scope='fc_logit_out') # Returns logits!
return net
def dcgan_generator(images, output_shape, hparams, scope=None):
"""Transforms the visual style of the input images.
Args:
images: A `Tensor` of shape [batch_size, height, width, channels].
output_shape: A list or tuple of 3 elements: the output height, width and
number of channels.
hparams: hparams object with generator parameters
scope: Scope to place generator inside
Returns:
A `Tensor` of shape [batch_size, height, width, output_channels] which
represents the result of style transfer.
Raises:
ValueError: If `output_shape` is not a list or tuple or if it doesn't have
three elements or if `output_shape` or `images` arent square.
"""
if not isinstance(output_shape, (tuple, list)):
raise ValueError('output_shape must be a tuple or list.')
elif len(output_shape) != 3:
raise ValueError('output_shape must have three elements.')
if output_shape[0] != output_shape[1]:
raise ValueError('output_shape must be square')
if images.shape.as_list()[1] != images.shape.as_list()[2]:
raise ValueError('images height and width must match.')
outdim = output_shape[0]
indim = images.shape.as_list()[1]
num_iterations = int(math.ceil(math.log(float(outdim) / float(indim), 2.0)))
with slim.arg_scope(
[slim.conv2d, slim.conv2d_transpose],
kernel_size=[hparams.generator_kernel_size] * 2,
stride=2):
with tf.variable_scope(scope or 'generator'):
net = images
# Repeatedly halve # filters until = hparams.decode_filters in last layer
for i in range(num_iterations):
num_filters = hparams.num_decoder_filters * 2**(num_iterations - i - 1)
net = slim.conv2d_transpose(net, num_filters, scope='deconv_%s' % i)
# Crop down to desired size (e.g. 32x32 -> 28x28)
dif = net.shape.as_list()[1] - outdim
low = dif / 2
high = net.shape.as_list()[1] - low
net = net[:, low:high, low:high, :]
# No batch norm on generator output
net = slim.conv2d(
net,
output_shape[2],
kernel_size=[1, 1],
stride=1,
normalizer_fn=None,
activation_fn=tf.tanh,
scope='conv_out')
return net
def dcgan(target_images, latent_vars, hparams, scope='dcgan'):
"""Creates the PixelDA model.
Args:
target_images: A `Tensor` of shape [batch_size, height, width, 3]
sampled from the image domain to which we want to transfer.
latent_vars: dictionary of 'key': Tensor of shape [batch_size, N]
hparams: The hyperparameter map.
scope: Surround generator component with this scope
Returns:
A dictionary of model outputs.
"""
proj_shape = [
hparams.projection_shape_size, hparams.projection_shape_size,
hparams.projection_shape_channels
]
source_volume = project_latent_vars(
hparams, proj_shape, latent_vars, combine_method='concat')
###################################################
# Transfer the source images to the target style. #
###################################################
with tf.variable_scope(scope, 'generator', [target_images]):
transferred_images = dcgan_generator(
source_volume,
output_shape=target_images.shape.as_list()[1:4],
hparams=hparams)
assert transferred_images.shape.as_list() == target_images.shape.as_list()
return {'transferred_images': transferred_images}
def resnet_generator(images, output_shape, hparams, latent_vars=None):
"""Creates a ResNet-based generator.
Args:
images: A `Tensor` of shape [batch_size, height, width, num_channels]
sampled from the image domain from which we want to transfer
output_shape: A length-3 array indicating the height, width and channels of
the output.
hparams: The hyperparameter map.
latent_vars: dictionary of 'key': Tensor of shape [batch_size, N]
Returns:
A dictionary of model outputs.
"""
with tf.variable_scope('generator'):
if latent_vars:
noise_channel = project_latent_vars(
hparams,
proj_shape=images.shape.as_list()[1:3] + [1],
latent_vars=latent_vars,
combine_method='concat')
images = tf.concat([images, noise_channel], 3)
transferred_images, end_points = resnet_stack(
images,
output_shape=output_shape,
hparams=hparams,
scope='resnet_stack')
end_points['transferred_images'] = transferred_images
return end_points
def residual_interpretation_block(images, hparams, scope):
"""Learns a residual image which is added to the incoming image.
Args:
images: A `Tensor` of size [batch_size, height, width, 3]
hparams: The hyperparameters struct.
scope: The name of the variable op scope.
Returns:
The updated images.
"""
with tf.variable_scope(scope):
with slim.arg_scope(
[slim.conv2d],
normalizer_fn=None,
kernel_size=[hparams.generator_kernel_size] * 2):
net = images
for _ in range(hparams.res_int_convs):
net = slim.conv2d(
net, hparams.res_int_filters, activation_fn=tf.nn.relu)
net = slim.conv2d(net, 3, activation_fn=tf.nn.tanh)
# Add the residual
images += net
# Clip the output
images = tf.maximum(images, -1.0)
images = tf.minimum(images, 1.0)
return images
def residual_interpretation_generator(images,
is_training,
hparams,
latent_vars=None):
"""Creates a generator producing purely residual transformations.
A residual generator differs from the resnet generator in that each 'block' of
the residual generator produces a residual image. Consequently, the 'progress'
of the model generation process can be directly observed at inference time,
making it easier to diagnose and understand.
Args:
images: A `Tensor` of shape [batch_size, height, width, num_channels]
sampled from the image domain from which we want to transfer. It is
assumed that the images are centered between -1 and 1.
is_training: whether or not the model is training.
hparams: The hyperparameter map.
latent_vars: dictionary of 'key': Tensor of shape [batch_size, N]
Returns:
A dictionary of model outputs.
"""
end_points = {}
with tf.variable_scope('generator'):
if latent_vars:
projected_latent = project_latent_vars(
hparams,
proj_shape=images.shape.as_list()[1:3] + [images.shape.as_list()[-1]],
latent_vars=latent_vars,
combine_method='sum')
images += projected_latent
with tf.variable_scope(None, 'residual_style_transfer', [images]):
for i in range(hparams.res_int_blocks):
images = residual_interpretation_block(images, hparams,
'residual_%d' % i)
end_points['transferred_images_%d' % i] = images
end_points['transferred_images'] = images
return end_points
def simple_generator(source_images, target_images, is_training, hparams,
latent_vars):
"""Simple generator architecture (stack of convs) for trying small models."""
end_points = {}
with tf.variable_scope('generator'):
feed_source_images = source_images
if latent_vars:
projected_latent = project_latent_vars(
hparams,
proj_shape=source_images.shape.as_list()[1:3] + [1],
latent_vars=latent_vars,
combine_method='concat')
feed_source_images = tf.concat([source_images, projected_latent], 3)
end_points = {}
###################################################
# Transfer the source images to the target style. #
###################################################
with slim.arg_scope(
[slim.conv2d],
normalizer_fn=slim.batch_norm,
stride=1,
kernel_size=[hparams.generator_kernel_size] * 2):
net = feed_source_images
# N convolutions
for i in range(1, hparams.simple_num_conv_layers):
normalizer_fn = None
if i != 0:
normalizer_fn = slim.batch_norm
net = slim.conv2d(
net,
hparams.simple_conv_filters,
normalizer_fn=normalizer_fn,
activation_fn=tf.nn.relu)
# Project back to right # image channels
net = slim.conv2d(
net,
target_images.shape.as_list()[-1],
kernel_size=[1, 1],
stride=1,
normalizer_fn=None,
activation_fn=tf.tanh,
scope='conv_out')
transferred_images = net
assert transferred_images.shape.as_list() == target_images.shape.as_list()
end_points['transferred_images'] = transferred_images
return end_points
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains functions for preprocessing the inputs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
def preprocess_classification(image, labels, is_training=False):
"""Preprocesses the image and labels for classification purposes.
Preprocessing includes shifting the images to be 0-centered between -1 and 1.
This is not only a popular method of preprocessing (inception) but is also
the mechanism used by DSNs.
Args:
image: A `Tensor` of size [height, width, 3].
labels: A dictionary of labels.
is_training: Whether or not we're training the model.
Returns:
The preprocessed image and labels.
"""
# If the image is uint8, this will scale it to 0-1.
image = tf.image.convert_image_dtype(image, tf.float32)
image -= 0.5
image *= 2
return image, labels
def preprocess_style_transfer(image,
labels,
augment=False,
size=None,
is_training=False):
"""Preprocesses the image and labels for style transfer purposes.
Args:
image: A `Tensor` of size [height, width, 3].
labels: A dictionary of labels.
augment: Whether to apply data augmentation to inputs
size: The height and width to which images should be resized. If left as
`None`, then no resizing is performed
is_training: Whether or not we're training the model
Returns:
The preprocessed image and labels. Scaled to [-1, 1]
"""
# If the image is uint8, this will scale it to 0-1.
image = tf.image.convert_image_dtype(image, tf.float32)
if augment and is_training:
image = image_augmentation(image)
if size:
image = resize_image(image, size)
image -= 0.5
image *= 2
return image, labels
def image_augmentation(image):
"""Performs data augmentation by randomly permuting the inputs.
Args:
image: A float `Tensor` of size [height, width, channels] with values
in range[0,1].
Returns:
The mutated batch of images
"""
# Apply photometric data augmentation (contrast etc.)
num_channels = image.shape_as_list()[-1]
if num_channels == 4:
# Only augment image part
image, depth = image[:, :, 0:3], image[:, :, 3:4]
elif num_channels == 1:
image = tf.image.grayscale_to_rgb(image)
image = tf.image.random_brightness(image, max_delta=0.1)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_hue(image, max_delta=0.032)
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
image = tf.clip_by_value(image, 0, 1.0)
if num_channels == 4:
image = tf.concat(2, [image, depth])
elif num_channels == 1:
image = tf.image.rgb_to_grayscale(image)
return image
def resize_image(image, size=None):
"""Resize image to target size.
Args:
image: A `Tensor` of size [height, width, 3].
size: (height, width) to resize image to.
Returns:
resized image
"""
if size is None:
raise ValueError('Must specify size')
if image.shape_as_list()[:2] == size:
# Don't resize if not necessary
return image
image = tf.expand_dims(image, 0)
image = tf.image.resize_images(image, size)
image = tf.squeeze(image, 0)
return image
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for domain_adaptation.pixel_domain_adaptation.pixelda_preprocess."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
class PixelDAPreprocessTest(tf.test.TestCase):
def assert_preprocess_classification_is_centered(self, dtype, is_training):
tf.set_random_seed(0)
if dtype == tf.uint8:
image = tf.random_uniform((100, 200, 3), maxval=255, dtype=tf.int64)
image = tf.cast(image, tf.uint8)
else:
image = tf.random_uniform((100, 200, 3), maxval=1.0, dtype=dtype)
labels = {}
image, labels = pixelda_preprocess.preprocess_classification(
image, labels, is_training=is_training)
with self.test_session() as sess:
np_image = sess.run(image)
self.assertTrue(np_image.min() <= -0.95)
self.assertTrue(np_image.min() >= -1.0)
self.assertTrue(np_image.max() >= 0.95)
self.assertTrue(np_image.max() <= 1.0)
def testPreprocessClassificationZeroCentersUint8DuringTrain(self):
self.assert_preprocess_classification_is_centered(
tf.uint8, is_training=True)
def testPreprocessClassificationZeroCentersUint8DuringTest(self):
self.assert_preprocess_classification_is_centered(
tf.uint8, is_training=False)
def testPreprocessClassificationZeroCentersFloatDuringTrain(self):
self.assert_preprocess_classification_is_centered(
tf.float32, is_training=True)
def testPreprocessClassificationZeroCentersFloatDuringTest(self):
self.assert_preprocess_classification_is_centered(
tf.float32, is_training=False)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task towers for PixelDA model."""
import tensorflow as tf
slim = tf.contrib.slim
def add_task_specific_model(images,
hparams,
num_classes=10,
is_training=False,
reuse_private=False,
private_scope=None,
reuse_shared=False,
shared_scope=None):
"""Create a classifier for the given images.
The classifier is composed of a few 'private' layers followed by a few
'shared' layers. This lets us account for different image 'style', while
sharing the last few layers as 'content' layers.
Args:
images: A `Tensor` of size [batch_size, height, width, 3].
hparams: model hparams
num_classes: The number of output classes.
is_training: whether model is training
reuse_private: Whether or not to reuse the private weights, which are the
first few layers in the classifier
private_scope: The name of the variable_scope for the private (unshared)
components of the classifier.
reuse_shared: Whether or not to reuse the shared weights, which are the last
few layers in the classifier
shared_scope: The name of the variable_scope for the shared components of
the classifier.
Returns:
The logits, a `Tensor` of shape [batch_size, num_classes].
Raises:
ValueError: If hparams.task_classifier is an unknown value
"""
model = hparams.task_tower
# Make sure the classifier name shows up in graph
shared_scope = shared_scope or (model + '_shared')
kwargs = {
'num_classes': num_classes,
'is_training': is_training,
'reuse_private': reuse_private,
'reuse_shared': reuse_shared,
}
if private_scope:
kwargs['private_scope'] = private_scope
if shared_scope:
kwargs['shared_scope'] = shared_scope
quaternion_pred = None
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
activation_fn=tf.nn.relu,
weights_regularizer=tf.contrib.layers.l2_regularizer(
hparams.weight_decay_task_classifier)):
with slim.arg_scope([slim.conv2d], padding='SAME'):
if model == 'doubling_pose_estimator':
logits, quaternion_pred = doubling_cnn_class_and_quaternion(
images, num_private_layers=hparams.num_private_layers, **kwargs)
elif model == 'mnist':
logits, _ = mnist_classifier(images, **kwargs)
elif model == 'svhn':
logits, _ = svhn_classifier(images, **kwargs)
elif model == 'gtsrb':
logits, _ = gtsrb_classifier(images, **kwargs)
elif model == 'pose_mini':
logits, quaternion_pred = pose_mini_tower(images, **kwargs)
else:
raise ValueError('Unknown task classifier %s' % model)
return logits, quaternion_pred
#####################################
# Classifiers used in the DSN paper #
#####################################
def mnist_classifier(images,
is_training=False,
num_classes=10,
reuse_private=False,
private_scope='mnist',
reuse_shared=False,
shared_scope='task_model'):
"""Creates the convolutional MNIST model from the gradient reversal paper.
Note that since the output is a set of 'logits', the values fall in the
interval of (-infinity, infinity). Consequently, to convert the outputs to a
probability distribution over the characters, one will need to convert them
using the softmax function:
logits, endpoints = conv_mnist(images, is_training=False)
predictions = tf.nn.softmax(logits)
Args:
images: the MNIST digits, a tensor of size [batch_size, 28, 28, 1].
is_training: specifies whether or not we're currently training the model.
This variable will determine the behaviour of the dropout layer.
num_classes: the number of output classes to use.
Returns:
the output logits, a tensor of size [batch_size, num_classes].
a dictionary with key/values the layer names and tensors.
"""
net = {}
with tf.variable_scope(private_scope, reuse=reuse_private):
net['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1')
net['pool1'] = slim.max_pool2d(net['conv1'], [2, 2], 2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net['conv2'] = slim.conv2d(net['pool1'], 48, [5, 5], scope='conv2')
net['pool2'] = slim.max_pool2d(net['conv2'], [2, 2], 2, scope='pool2')
net['fc3'] = slim.fully_connected(
slim.flatten(net['pool2']), 100, scope='fc3')
net['fc4'] = slim.fully_connected(
slim.flatten(net['fc3']), 100, scope='fc4')
logits = slim.fully_connected(
net['fc4'], num_classes, activation_fn=None, scope='fc5')
return logits, net
def svhn_classifier(images,
is_training=False,
num_classes=10,
reuse_private=False,
private_scope=None,
reuse_shared=False,
shared_scope='task_model'):
"""Creates the convolutional SVHN model from the gradient reversal paper.
Note that since the output is a set of 'logits', the values fall in the
interval of (-infinity, infinity). Consequently, to convert the outputs to a
probability distribution over the characters, one will need to convert them
using the softmax function:
logits = mnist.Mnist(images, is_training=False)
predictions = tf.nn.softmax(logits)
Args:
images: the SVHN digits, a tensor of size [batch_size, 40, 40, 3].
is_training: specifies whether or not we're currently training the model.
This variable will determine the behaviour of the dropout layer.
num_classes: the number of output classes to use.
Returns:
the output logits, a tensor of size [batch_size, num_classes].
a dictionary with key/values the layer names and tensors.
"""
net = {}
with tf.variable_scope(private_scope, reuse=reuse_private):
net['conv1'] = slim.conv2d(images, 64, [5, 5], scope='conv1')
net['pool1'] = slim.max_pool2d(net['conv1'], [3, 3], 2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net['conv2'] = slim.conv2d(net['pool1'], 64, [5, 5], scope='conv2')
net['pool2'] = slim.max_pool2d(net['conv2'], [3, 3], 2, scope='pool2')
net['conv3'] = slim.conv2d(net['pool2'], 128, [5, 5], scope='conv3')
net['fc3'] = slim.fully_connected(
slim.flatten(net['conv3']), 3072, scope='fc3')
net['fc4'] = slim.fully_connected(
slim.flatten(net['fc3']), 2048, scope='fc4')
logits = slim.fully_connected(
net['fc4'], num_classes, activation_fn=None, scope='fc5')
return logits, net
def gtsrb_classifier(images,
is_training=False,
num_classes=43,
reuse_private=False,
private_scope='gtsrb',
reuse_shared=False,
shared_scope='task_model'):
"""Creates the convolutional GTSRB model from the gradient reversal paper.
Note that since the output is a set of 'logits', the values fall in the
interval of (-infinity, infinity). Consequently, to convert the outputs to a
probability distribution over the characters, one will need to convert them
using the softmax function:
logits = mnist.Mnist(images, is_training=False)
predictions = tf.nn.softmax(logits)
Args:
images: the SVHN digits, a tensor of size [batch_size, 40, 40, 3].
is_training: specifies whether or not we're currently training the model.
This variable will determine the behaviour of the dropout layer.
num_classes: the number of output classes to use.
reuse_private: Whether or not to reuse the private components of the model.
private_scope: The name of the private scope.
reuse_shared: Whether or not to reuse the shared components of the model.
shared_scope: The name of the shared scope.
Returns:
the output logits, a tensor of size [batch_size, num_classes].
a dictionary with key/values the layer names and tensors.
"""
net = {}
with tf.variable_scope(private_scope, reuse=reuse_private):
net['conv1'] = slim.conv2d(images, 96, [5, 5], scope='conv1')
net['pool1'] = slim.max_pool2d(net['conv1'], [2, 2], 2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net['conv2'] = slim.conv2d(net['pool1'], 144, [3, 3], scope='conv2')
net['pool2'] = slim.max_pool2d(net['conv2'], [2, 2], 2, scope='pool2')
net['conv3'] = slim.conv2d(net['pool2'], 256, [5, 5], scope='conv3')
net['pool3'] = slim.max_pool2d(net['conv3'], [2, 2], 2, scope='pool3')
net['fc3'] = slim.fully_connected(
slim.flatten(net['pool3']), 512, scope='fc3')
logits = slim.fully_connected(
net['fc3'], num_classes, activation_fn=None, scope='fc4')
return logits, net
#########################
# pose_mini task towers #
#########################
def pose_mini_tower(images,
num_classes=11,
is_training=False,
reuse_private=False,
private_scope='pose_mini',
reuse_shared=False,
shared_scope='task_model'):
"""Task tower for the pose_mini dataset."""
with tf.variable_scope(private_scope, reuse=reuse_private):
net = slim.conv2d(images, 32, [5, 5], scope='conv1')
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net = slim.conv2d(net, 64, [5, 5], scope='conv2')
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool2')
net = slim.flatten(net)
net = slim.fully_connected(net, 128, scope='fc3')
net = slim.dropout(net, 0.5, is_training=is_training, scope='dropout')
with tf.variable_scope('quaternion_prediction'):
quaternion_pred = slim.fully_connected(
net, 4, activation_fn=tf.tanh, scope='fc_q')
quaternion_pred = tf.nn.l2_normalize(quaternion_pred, 1)
logits = slim.fully_connected(
net, num_classes, activation_fn=None, scope='fc4')
return logits, quaternion_pred
def doubling_cnn_class_and_quaternion(images,
num_private_layers=1,
num_classes=10,
is_training=False,
reuse_private=False,
private_scope='doubling_cnn',
reuse_shared=False,
shared_scope='task_model'):
"""Alternate conv, pool while doubling filter count."""
net = images
depth = 32
layer_id = 1
with tf.variable_scope(private_scope, reuse=reuse_private):
while num_private_layers > 0 and net.shape.as_list()[1] > 5:
net = slim.conv2d(net, depth, [3, 3], scope='conv%s' % layer_id)
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool%s' % layer_id)
depth *= 2
layer_id += 1
num_private_layers -= 1
with tf.variable_scope(shared_scope, reuse=reuse_shared):
while net.shape.as_list()[1] > 5:
net = slim.conv2d(net, depth, [3, 3], scope='conv%s' % layer_id)
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool%s' % layer_id)
depth *= 2
layer_id += 1
net = slim.flatten(net)
net = slim.fully_connected(net, 100, scope='fc1')
net = slim.dropout(net, 0.5, is_training=is_training, scope='dropout')
quaternion_pred = slim.fully_connected(
net, 4, activation_fn=tf.tanh, scope='fc_q')
quaternion_pred = tf.nn.l2_normalize(quaternion_pred, 1)
logits = slim.fully_connected(
net, num_classes, activation_fn=None, scope='fc_logits')
return logits, quaternion_pred
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Trains the PixelDA model."""
from functools import partial
import os
# Dependency imports
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
from domain_adaptation.pixel_domain_adaptation import pixelda_losses
from domain_adaptation.pixel_domain_adaptation import pixelda_model
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
from domain_adaptation.pixel_domain_adaptation import pixelda_utils
from domain_adaptation.pixel_domain_adaptation.hparams import create_hparams
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('master', '', 'BNS name of the TensorFlow master to use.')
flags.DEFINE_integer(
'ps_tasks', 0,
'The number of parameter servers. If the value is 0, then the parameters '
'are handled locally by the worker.')
flags.DEFINE_integer(
'task', 0,
'The Task ID. This value is used when training with multiple workers to '
'identify each worker.')
flags.DEFINE_string('train_log_dir', '/tmp/pixelda/',
'Directory where to write event logs.')
flags.DEFINE_integer(
'save_summaries_steps', 500,
'The frequency with which summaries are saved, in seconds.')
flags.DEFINE_integer('save_interval_secs', 300,
'The frequency with which the model is saved, in seconds.')
flags.DEFINE_boolean('summarize_gradients', False,
'Whether to summarize model gradients')
flags.DEFINE_integer(
'print_loss_steps', 100,
'The frequency with which the losses are printed, in steps.')
flags.DEFINE_string('source_dataset', 'mnist', 'The name of the source dataset.'
' If hparams="arch=dcgan", this flag is ignored.')
flags.DEFINE_string('target_dataset', 'mnist_m',
'The name of the target dataset.')
flags.DEFINE_string('source_split_name', 'train',
'Name of the train split for the source.')
flags.DEFINE_string('target_split_name', 'train',
'Name of the train split for the target.')
flags.DEFINE_string('dataset_dir', '',
'The directory where the datasets can be found.')
flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
flags.DEFINE_integer('num_preprocessing_threads', 4,
'The number of threads used to create the batches.')
# HParams
flags.DEFINE_string('hparams', '', 'Comma separated hyperparameter values')
def _get_vars_and_update_ops(hparams, scope):
"""Returns the variables and update ops for a particular variable scope.
Args:
hparams: The hyperparameters struct.
scope: The variable scope.
Returns:
A tuple consisting of trainable variables and update ops.
"""
is_trainable = lambda x: x in tf.trainable_variables()
var_list = filter(is_trainable, slim.get_model_variables(scope))
global_step = slim.get_or_create_global_step()
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)
tf.logging.info('All variables for scope: %s',
slim.get_model_variables(scope))
tf.logging.info('Trainable variables for scope: %s', var_list)
return var_list, update_ops
def _train(discriminator_train_op,
generator_train_op,
logdir,
master='',
is_chief=True,
scaffold=None,
hooks=None,
chief_only_hooks=None,
save_checkpoint_secs=600,
save_summaries_steps=100,
hparams=None):
"""Runs the training loop.
Args:
discriminator_train_op: A `Tensor` that, when executed, will apply the
gradients and return the loss value for the discriminator.
generator_train_op: A `Tensor` that, when executed, will apply the
gradients and return the loss value for the generator.
logdir: The directory where the graph and checkpoints are saved.
master: The URL of the master.
is_chief: Specifies whether or not the training is being run by the primary
replica during replica training.
scaffold: An tf.train.Scaffold instance.
hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
training loop.
chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
inside the training loop for the chief trainer only.
save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
using a default checkpoint saver. If `save_checkpoint_secs` is set to
`None`, then the default checkpoint saver isn't used.
save_summaries_steps: The frequency, in number of global steps, that the
summaries are written to disk using a default summary saver. If
`save_summaries_steps` is set to `None`, then the default summary saver
isn't used.
hparams: The hparams struct.
Returns:
the value of the loss function after training.
Raises:
ValueError: if `logdir` is `None` and either `save_checkpoint_secs` or
`save_summaries_steps` are `None.
"""
global_step = slim.get_or_create_global_step()
scaffold = scaffold or tf.train.Scaffold()
hooks = hooks or []
if is_chief:
session_creator = tf.train.ChiefSessionCreator(
scaffold=scaffold, checkpoint_dir=logdir, master=master)
if chief_only_hooks:
hooks.extend(chief_only_hooks)
hooks.append(tf.train.StepCounterHook(output_dir=logdir))
if save_summaries_steps:
if logdir is None:
raise ValueError(
'logdir cannot be None when save_summaries_steps is None')
hooks.append(
tf.train.SummarySaverHook(
scaffold=scaffold,
save_steps=save_summaries_steps,
output_dir=logdir))
if save_checkpoint_secs:
if logdir is None:
raise ValueError(
'logdir cannot be None when save_checkpoint_secs is None')
hooks.append(
tf.train.CheckpointSaverHook(
logdir, save_secs=save_checkpoint_secs, scaffold=scaffold))
else:
session_creator = tf.train.WorkerSessionCreator(
scaffold=scaffold, master=master)
with tf.train.MonitoredSession(
session_creator=session_creator, hooks=hooks) as session:
loss = None
while not session.should_stop():
# Run the domain classifier op X times.
for _ in range(hparams.discriminator_steps):
if session.should_stop():
return loss
loss, np_global_step = session.run(
[discriminator_train_op, global_step])
if np_global_step % FLAGS.print_loss_steps == 0:
tf.logging.info('Step %d: Discriminator Loss = %.2f', np_global_step,
loss)
# Run the generator op X times.
for _ in range(hparams.generator_steps):
if session.should_stop():
return loss
loss, np_global_step = session.run([generator_train_op, global_step])
if np_global_step % FLAGS.print_loss_steps == 0:
tf.logging.info('Step %d: Generator Loss = %.2f', np_global_step,
loss)
return loss
def run_training(run_dir, checkpoint_dir, hparams):
"""Runs the training loop.
Args:
run_dir: The directory where training specific logs are placed
checkpoint_dir: The directory where the checkpoints and log files are
stored.
hparams: The hyperparameters struct.
Raises:
ValueError: if hparams.arch is not recognized.
"""
for path in [run_dir, checkpoint_dir]:
if not tf.gfile.Exists(path):
tf.gfile.MakeDirs(path)
# Serialize hparams to log dir
hparams_filename = os.path.join(checkpoint_dir, 'hparams.json')
with tf.gfile.FastGFile(hparams_filename, 'w') as f:
f.write(hparams.to_json())
with tf.Graph().as_default():
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
global_step = slim.get_or_create_global_step()
#########################
# Preprocess the inputs #
#########################
target_dataset = dataset_factory.get_dataset(
FLAGS.target_dataset,
split_name='train',
dataset_dir=FLAGS.dataset_dir)
target_images, _ = dataset_factory.provide_batch(
FLAGS.target_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
hparams.batch_size, FLAGS.num_preprocessing_threads)
num_target_classes = target_dataset.num_classes
if hparams.arch not in ['dcgan']:
source_dataset = dataset_factory.get_dataset(
FLAGS.source_dataset,
split_name='train',
dataset_dir=FLAGS.dataset_dir)
num_source_classes = source_dataset.num_classes
source_images, source_labels = dataset_factory.provide_batch(
FLAGS.source_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
hparams.batch_size, FLAGS.num_preprocessing_threads)
# Data provider provides 1 hot labels, but we expect categorical.
source_labels['class'] = tf.argmax(source_labels['classes'], 1)
del source_labels['classes']
if num_source_classes != num_target_classes:
raise ValueError(
'Source and Target datasets must have same number of classes. '
'Are %d and %d' % (num_source_classes, num_target_classes))
else:
source_images = None
source_labels = None
####################
# Define the model #
####################
end_points = pixelda_model.create_model(
hparams,
target_images,
source_images=source_images,
source_labels=source_labels,
is_training=True,
num_classes=num_target_classes)
#################################
# Get the variables to optimize #
#################################
generator_vars, generator_update_ops = _get_vars_and_update_ops(
hparams, 'generator')
discriminator_vars, discriminator_update_ops = _get_vars_and_update_ops(
hparams, 'discriminator')
########################
# Configure the losses #
########################
generator_loss = pixelda_losses.g_step_loss(
source_images,
source_labels,
end_points,
hparams,
num_classes=num_target_classes)
discriminator_loss = pixelda_losses.d_step_loss(
end_points, source_labels, num_target_classes, hparams)
###########################
# Create the training ops #
###########################
learning_rate = hparams.learning_rate
if hparams.lr_decay_steps:
learning_rate = tf.train.exponential_decay(
learning_rate,
slim.get_or_create_global_step(),
decay_steps=hparams.lr_decay_steps,
decay_rate=hparams.lr_decay_rate,
staircase=True)
tf.summary.scalar('Learning_rate', learning_rate)
if hparams.discriminator_steps == 0:
discriminator_train_op = tf.no_op()
else:
discriminator_optimizer = tf.train.AdamOptimizer(
learning_rate, beta1=hparams.adam_beta1)
discriminator_train_op = slim.learning.create_train_op(
discriminator_loss,
discriminator_optimizer,
update_ops=discriminator_update_ops,
variables_to_train=discriminator_vars,
clip_gradient_norm=hparams.clip_gradient_norm,
summarize_gradients=FLAGS.summarize_gradients)
if hparams.generator_steps == 0:
generator_train_op = tf.no_op()
else:
generator_optimizer = tf.train.AdamOptimizer(
learning_rate, beta1=hparams.adam_beta1)
generator_train_op = slim.learning.create_train_op(
generator_loss,
generator_optimizer,
update_ops=generator_update_ops,
variables_to_train=generator_vars,
clip_gradient_norm=hparams.clip_gradient_norm,
summarize_gradients=FLAGS.summarize_gradients)
#############
# Summaries #
#############
pixelda_utils.summarize_model(end_points)
pixelda_utils.summarize_transferred_grid(
end_points['transferred_images'], source_images, name='Transferred')
if 'source_images_recon' in end_points:
pixelda_utils.summarize_transferred_grid(
end_points['source_images_recon'],
source_images,
name='Source Reconstruction')
pixelda_utils.summaries_color_distributions(end_points['transferred_images'],
'Transferred')
pixelda_utils.summaries_color_distributions(target_images, 'Target')
if source_images is not None:
pixelda_utils.summarize_transferred(source_images,
end_points['transferred_images'])
pixelda_utils.summaries_color_distributions(source_images, 'Source')
pixelda_utils.summaries_color_distributions(
tf.abs(source_images - end_points['transferred_images']),
'Abs(Source_minus_Transferred)')
number_of_steps = None
if hparams.num_training_examples:
# Want to control by amount of data seen, not # steps
number_of_steps = hparams.num_training_examples / hparams.batch_size
hooks = [tf.train.StepCounterHook(),]
chief_only_hooks = [
tf.train.CheckpointSaverHook(
saver=tf.train.Saver(),
checkpoint_dir=run_dir,
save_secs=FLAGS.save_interval_secs)
]
if number_of_steps:
hooks.append(tf.train.StopAtStepHook(last_step=number_of_steps))
_train(
discriminator_train_op,
generator_train_op,
logdir=run_dir,
master=FLAGS.master,
is_chief=FLAGS.task == 0,
hooks=hooks,
chief_only_hooks=chief_only_hooks,
save_checkpoint_secs=None,
save_summaries_steps=FLAGS.save_summaries_steps,
hparams=hparams)
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
hparams = create_hparams(FLAGS.hparams)
run_training(
run_dir=FLAGS.train_log_dir,
checkpoint_dir=FLAGS.train_log_dir,
hparams=hparams)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for PixelDA model."""
import math
# Dependency imports
import tensorflow as tf
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
def remove_depth(images):
"""Takes a batch of images and remove depth channel if present."""
if images.shape.as_list()[-1] == 4:
return images[:, :, :, 0:3]
return images
def image_grid(images, max_grid_size=4):
"""Given images and N, return first N^2 images as an NxN image grid.
Args:
images: a `Tensor` of size [batch_size, height, width, channels]
max_grid_size: Maximum image grid height/width
Returns:
Single image batch, of dim [1, h*n, w*n, c]
"""
images = remove_depth(images)
batch_size = images.shape.as_list()[0]
grid_size = min(int(math.sqrt(batch_size)), max_grid_size)
assert images.shape.as_list()[0] >= grid_size * grid_size
# If we have a depth channel
if images.shape.as_list()[-1] == 4:
images = images[:grid_size * grid_size, :, :, 0:3]
depth = tf.image.grayscale_to_rgb(images[:grid_size * grid_size, :, :, 3:4])
images = tf.reshape(images, [-1, images.shape.as_list()[2], 3])
split = tf.split(0, grid_size, images)
depth = tf.reshape(depth, [-1, images.shape.as_list()[2], 3])
depth_split = tf.split(0, grid_size, depth)
grid = tf.concat(split + depth_split, 1)
return tf.expand_dims(grid, 0)
else:
images = images[:grid_size * grid_size, :, :, :]
images = tf.reshape(
images, [-1, images.shape.as_list()[2],
images.shape.as_list()[3]])
split = tf.split(images, grid_size, 0)
grid = tf.concat(split, 1)
return tf.expand_dims(grid, 0)
def source_and_output_image_grid(output_images,
source_images=None,
max_grid_size=4):
"""Create NxN image grid for output, concatenate source grid if given.
Makes grid out of output_images and, if provided, source_images, and
concatenates them.
Args:
output_images: [batch_size, h, w, c] tensor of images
source_images: optional[batch_size, h, w, c] tensor of images
max_grid_size: Image grid height/width
Returns:
Single image batch, of dim [1, h*n, w*n, c]
"""
output_grid = image_grid(output_images, max_grid_size=max_grid_size)
if source_images is not None:
source_grid = image_grid(source_images, max_grid_size=max_grid_size)
# Make sure they have the same # of channels before concat
# Assumes either 1 or 3 channels
if output_grid.shape.as_list()[-1] != source_grid.shape.as_list()[-1]:
if output_grid.shape.as_list()[-1] == 1:
output_grid = tf.tile(output_grid, [1, 1, 1, 3])
if source_grid.shape.as_list()[-1] == 1:
source_grid = tf.tile(source_grid, [1, 1, 1, 3])
output_grid = tf.concat([output_grid, source_grid], 1)
return output_grid
def summarize_model(end_points):
"""Summarizes the given model via its end_points.
Args:
end_points: A dictionary of end_point names to `Tensor`.
"""
tf.summary.histogram('domain_logits_transferred',
tf.sigmoid(end_points['transferred_domain_logits']))
tf.summary.histogram('domain_logits_target',
tf.sigmoid(end_points['target_domain_logits']))
def summarize_transferred_grid(transferred_images,
source_images=None,
name='Transferred'):
"""Produces a visual grid summarization of the image transferrence.
Args:
transferred_images: A `Tensor` of size [batch_size, height, width, c].
source_images: A `Tensor` of size [batch_size, height, width, c].
name: Name to use in summary name
"""
if source_images is not None:
grid = source_and_output_image_grid(transferred_images, source_images)
else:
grid = image_grid(transferred_images)
tf.summary.image('%s_Images_Grid' % name, grid, max_outputs=1)
def summarize_transferred(source_images,
transferred_images,
max_images=20,
name='Transferred'):
"""Produces a visual summary of the image transferrence.
This summary displays the source image, transferred image, and a grayscale
difference image which highlights the differences between input and output.
Args:
source_images: A `Tensor` of size [batch_size, height, width, channels].
transferred_images: A `Tensor` of size [batch_size, height, width, channels]
max_images: The number of images to show.
name: Name to use in summary name
Raises:
ValueError: If number of channels in source and target are incompatible
"""
source_channels = source_images.shape.as_list()[-1]
transferred_channels = transferred_images.shape.as_list()[-1]
if source_channels < transferred_channels:
if source_channels != 1:
raise ValueError(
'Source must be 1 channel or same # of channels as target')
source_images = tf.tile(source_images, [1, 1, 1, transferred_channels])
if transferred_channels < source_channels:
if transferred_channels != 1:
raise ValueError(
'Target must be 1 channel or same # of channels as source')
transferred_images = tf.tile(transferred_images, [1, 1, 1, source_channels])
diffs = tf.abs(source_images - transferred_images)
diffs = tf.reduce_max(diffs, reduction_indices=[3], keep_dims=True)
diffs = tf.tile(diffs, [1, 1, 1, max(source_channels, transferred_channels)])
transition_images = tf.concat([
source_images,
transferred_images,
diffs,
], 2)
tf.summary.image(
'%s_difference' % name, transition_images, max_outputs=max_images)
def summaries_color_distributions(images, name):
"""Produces a histogram of the color distributions of the images.
Args:
images: A `Tensor` of size [batch_size, height, width, 3].
name: The name of the images being summarized.
"""
tf.summary.histogram('color_values/%s' % name, images)
def summarize_images(images, name):
"""Produces a visual summary of the given images.
Args:
images: A `Tensor` of size [batch_size, height, width, 3].
name: The name of the images being summarized.
"""
grid = image_grid(images)
tf.summary.image('%s_Images' % name, grid, max_outputs=1)
...@@ -322,7 +322,7 @@ bazel-bin/im2txt/run_inference \ ...@@ -322,7 +322,7 @@ bazel-bin/im2txt/run_inference \
Example output: Example output:
```shell ```
Captions for image COCO_val2014_000000224477.jpg: Captions for image COCO_val2014_000000224477.jpg:
0) a man riding a wave on top of a surfboard . (p=0.040413) 0) a man riding a wave on top of a surfboard . (p=0.040413)
1) a person riding a surf board on a wave (p=0.017452) 1) a person riding a surf board on a wave (p=0.017452)
......
...@@ -7,9 +7,7 @@ This code implements the model from the paper "[LFADS - Latent Factor Analysis v ...@@ -7,9 +7,7 @@ This code implements the model from the paper "[LFADS - Latent Factor Analysis v
The code is written in Python 2.7.6. You will also need: The code is written in Python 2.7.6. You will also need:
* **TensorFlow** version 1.1 ([install](http://tflearn.org/installation/)) - * **TensorFlow** version 1.2.1 ([install](https://www.tensorflow.org/install/)) -
there is an incompatibility with LFADS and TF v1.2, which we are in the
process of resolving
* **NumPy, SciPy, Matplotlib** ([install SciPy stack](https://www.scipy.org/install.html), contains all of them) * **NumPy, SciPy, Matplotlib** ([install SciPy stack](https://www.scipy.org/install.html), contains all of them)
* **h5py** ([install](https://pypi.python.org/pypi/h5py)) * **h5py** ([install](https://pypi.python.org/pypi/h5py))
...@@ -38,10 +36,10 @@ These synthetic datasets are provided 1. to gain insight into how the LFADS algo ...@@ -38,10 +36,10 @@ These synthetic datasets are provided 1. to gain insight into how the LFADS algo
## Train an LFADS model ## Train an LFADS model
Now that we have our example datasets, we can train some models! To spin up an LFADS model on the synthetic data, run any of the following commands. For the examples that are in the paper, the important hyperparameters are roughly replicated. Most hyperparameters are insensitive to small changes or won't ever be changed unless you want a very fine level of control. In the first example, all hyperparameter flags are enumerated for easy copy-pasting, but for the rest of the examples only the most important flags (~the first 8) are specified for brevity. For a full list of flags, their descriptions, and their default values, refer to the top of `run_lfads.py`. Please see Table 1 in the Online Methods of the associated paper for definitions of the most important hyperparameters. Now that we have our example datasets, we can train some models! To spin up an LFADS model on the synthetic data, run any of the following commands. For the examples that are in the paper, the important hyperparameters are roughly replicated. Most hyperparameters are insensitive to small changes or won't ever be changed unless you want a very fine level of control. In the first example, all hyperparameter flags are enumerated for easy copy-pasting, but for the rest of the examples only the most important flags (~the first 9) are specified for brevity. For a full list of flags, their descriptions, and their default values, refer to the top of `run_lfads.py`. Please see Table 1 in the Online Methods of the associated paper for definitions of the most important hyperparameters.
```sh ```sh
# Run LFADS on chaotic rnn data with no input pulses (g = 1.5) # Run LFADS on chaotic rnn data with no input pulses (g = 1.5) with spiking noise
$ python run_lfads.py --kind=train \ $ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \ --data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_no_inputs \ --data_filename_stem=chaotic_rnn_no_inputs \
...@@ -108,14 +106,16 @@ $ python run_lfads.py --kind=train \ ...@@ -108,14 +106,16 @@ $ python run_lfads.py --kind=train \
--data_filename_stem=chaotic_rnn_inputs_g2p5 \ --data_filename_stem=chaotic_rnn_inputs_g2p5 \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \ --lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
--co_dim=1 \ --co_dim=1 \
--factors_dim=20 --factors_dim=20 \
--output_dist=poisson
# Run LFADS on multi-session RNN data # Run LFADS on multi-session RNN data
$ python run_lfads.py --kind=train \ $ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \ --data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_multisession \ --data_filename_stem=chaotic_rnn_multisession \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_multisession \ --lfads_save_dir=/tmp/lfads_chaotic_rnn_multisession \
--factors_dim=10 --factors_dim=10 \
--output_dist=poisson
# Run LFADS on integration to bound model data # Run LFADS on integration to bound model data
$ python run_lfads.py --kind=train \ $ python run_lfads.py --kind=train \
...@@ -124,7 +124,8 @@ $ python run_lfads.py --kind=train \ ...@@ -124,7 +124,8 @@ $ python run_lfads.py --kind=train \
--lfads_save_dir=/tmp/lfads_itb_rnn \ --lfads_save_dir=/tmp/lfads_itb_rnn \
--co_dim=1 \ --co_dim=1 \
--factors_dim=20 \ --factors_dim=20 \
--controller_input_lag=0 --controller_input_lag=0 \
--output_dist=poisson
# Run LFADS on chaotic RNN data with labels # Run LFADS on chaotic RNN data with labels
$ python run_lfads.py --kind=train \ $ python run_lfads.py --kind=train \
...@@ -134,7 +135,20 @@ $ python run_lfads.py --kind=train \ ...@@ -134,7 +135,20 @@ $ python run_lfads.py --kind=train \
--co_dim=0 \ --co_dim=0 \
--factors_dim=20 \ --factors_dim=20 \
--controller_input_lag=0 \ --controller_input_lag=0 \
--ext_input_dim=1 --ext_input_dim=1 \
--output_dist=poisson
# Run LFADS on chaotic rnn data with no input pulses (g = 1.5) with Gaussian noise
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_no_inputs \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_no_inputs \
--co_dim=0 \
--factors_dim=20 \
--ext_input_dim=0 \
--controller_input_lag=1 \
--output_dist=gaussian \
``` ```
......
...@@ -43,6 +43,10 @@ The nested dictionary is the DATA DICTIONARY, which has the following keys: ...@@ -43,6 +43,10 @@ The nested dictionary is the DATA DICTIONARY, which has the following keys:
output adapter for each dataset. These matrices, if provided, must be of output adapter for each dataset. These matrices, if provided, must be of
size [data_dim x factors] where data_dim is the number of neurons recorded size [data_dim x factors] where data_dim is the number of neurons recorded
on that day, and factors is chosen and set through the '--factors' flag. on that day, and factors is chosen and set through the '--factors' flag.
'alignment_bias_c' - See alignment_matrix_cxf. This bias will used to
the offset for the alignment transformation. It will *subtract* off the
bias from the data, so pca style inits can align factors across sessions.
If one runs LFADS on data where the true rates are known for some trials, If one runs LFADS on data where the true rates are known for some trials,
(say simulated, testing data, as in the example shipped with the paper), then (say simulated, testing data, as in the example shipped with the paper), then
...@@ -277,7 +281,7 @@ class LFADS(object): ...@@ -277,7 +281,7 @@ class LFADS(object):
"""Create an LFADS model. """Create an LFADS model.
train - a model for training, sampling of posteriors is used train - a model for training, sampling of posteriors is used
posterior_sample_and_average - sample from the posterior, this is used posterior_sample_and_average - sample from the posterior, this is used
for evaluating the expected value of the outputs of LFADS, given a for evaluating the expected value of the outputs of LFADS, given a
specific input, by averaging over multiple samples from the approx specific input, by averaging over multiple samples from the approx
posterior. Also used for the lower bound on the negative posterior. Also used for the lower bound on the negative
...@@ -356,18 +360,36 @@ class LFADS(object): ...@@ -356,18 +360,36 @@ class LFADS(object):
for d, name in enumerate(dataset_names): for d, name in enumerate(dataset_names):
data_dim = hps.dataset_dims[name] data_dim = hps.dataset_dims[name]
in_mat_cxf = None in_mat_cxf = None
in_bias_1xf = None
align_bias_1xc = None
if datasets and 'alignment_matrix_cxf' in datasets[name].keys(): if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
dataset = datasets[name] dataset = datasets[name]
print("Using alignment matrix provided for dataset:", name) print("Using alignment matrix provided for dataset:", name)
in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32) in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
if in_mat_cxf.shape != (data_dim, factors_dim): if in_mat_cxf.shape != (data_dim, factors_dim):
raise ValueError("""Alignment matrix must have dimensions %d x %d raise ValueError("""Alignment matrix must have dimensions %d x %d
(data_dim x factors_dim), but currently has %d x %d."""% (data_dim x factors_dim), but currently has %d x %d."""%
(data_dim, factors_dim, in_mat_cxf.shape[0], (data_dim, factors_dim, in_mat_cxf.shape[0],
in_mat_cxf.shape[1])) in_mat_cxf.shape[1]))
if datasets and 'alignment_bias_c' in datasets[name].keys():
dataset = datasets[name]
print("Using alignment bias provided for dataset:", name)
align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
if align_bias_1xc.shape[1] != data_dim:
raise ValueError("""Alignment bias must have dimensions %d
(data_dim), but currently has %d."""%
(data_dim, in_mat_cxf.shape[0]))
if in_mat_cxf is not None and align_bias_1xc is not None:
# (data - alignment_bias) * W_in
# data * W_in - alignment_bias * W_in
# So b = -alignment_bias * W_in to accommodate PCA style offset.
in_bias_1xf = -np.dot(align_bias_1xc, in_mat_cxf)
in_fac_lin = init_linear(data_dim, used_in_factors_dim, do_bias=True, in_fac_lin = init_linear(data_dim, used_in_factors_dim, do_bias=True,
mat_init_value=in_mat_cxf, mat_init_value=in_mat_cxf,
bias_init_value=in_bias_1xf,
identity_if_possible=in_identity_if_poss, identity_if_possible=in_identity_if_poss,
normalized=False, name="x_2_infac_"+name, normalized=False, name="x_2_infac_"+name,
collections=['IO_transformations']) collections=['IO_transformations'])
...@@ -387,13 +409,22 @@ class LFADS(object): ...@@ -387,13 +409,22 @@ class LFADS(object):
dataset = datasets[name] dataset = datasets[name]
in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32) in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
out_mat_cxf = None if datasets and 'alignment_bias_c' in datasets[name].keys():
dataset = datasets[name]
align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
out_mat_fxc = None
out_bias_1xc = None
if in_mat_cxf is not None: if in_mat_cxf is not None:
out_mat_cxf = in_mat_cxf.T out_mat_fxc = np.linalg.pinv(in_mat_cxf)
if align_bias_1xc is not None:
out_bias_1xc = align_bias_1xc
if hps.output_dist == 'poisson': if hps.output_dist == 'poisson':
out_fac_lin = init_linear(factors_dim, data_dim, do_bias=True, out_fac_lin = init_linear(factors_dim, data_dim, do_bias=True,
mat_init_value=out_mat_cxf, mat_init_value=out_mat_fxc,
bias_init_value=out_bias_1xc,
identity_if_possible=out_identity_if_poss, identity_if_possible=out_identity_if_poss,
normalized=False, normalized=False,
name="fac_2_logrates_"+name, name="fac_2_logrates_"+name,
...@@ -403,13 +434,19 @@ class LFADS(object): ...@@ -403,13 +434,19 @@ class LFADS(object):
elif hps.output_dist == 'gaussian': elif hps.output_dist == 'gaussian':
out_fac_lin_mean = \ out_fac_lin_mean = \
init_linear(factors_dim, data_dim, do_bias=True, init_linear(factors_dim, data_dim, do_bias=True,
mat_init_value=out_mat_cxf, mat_init_value=out_mat_fxc,
bias_init_value=out_bias_1xc,
normalized=False, normalized=False,
name="fac_2_means_"+name, name="fac_2_means_"+name,
collections=['IO_transformations']) collections=['IO_transformations'])
out_fac_W_mean, out_fac_b_mean = out_fac_lin_mean
mat_init_value = np.zeros([factors_dim, data_dim]).astype(np.float32)
bias_init_value = np.ones([1, data_dim]).astype(np.float32)
out_fac_lin_logvar = \ out_fac_lin_logvar = \
init_linear(factors_dim, data_dim, do_bias=True, init_linear(factors_dim, data_dim, do_bias=True,
mat_init_value=out_mat_cxf, mat_init_value=mat_init_value,
bias_init_value=bias_init_value,
normalized=False, normalized=False,
name="fac_2_logvars_"+name, name="fac_2_logvars_"+name,
collections=['IO_transformations']) collections=['IO_transformations'])
...@@ -432,11 +469,15 @@ class LFADS(object): ...@@ -432,11 +469,15 @@ class LFADS(object):
pf_pairs_out_fac_Ws = zip(preds, fns_out_fac_Ws) pf_pairs_out_fac_Ws = zip(preds, fns_out_fac_Ws)
pf_pairs_out_fac_bs = zip(preds, fns_out_fac_bs) pf_pairs_out_fac_bs = zip(preds, fns_out_fac_bs)
case_default = lambda: tf.constant([-8675309.0]) def _case_with_no_default(pairs):
this_in_fac_W = tf.case(pf_pairs_in_fac_Ws, case_default, exclusive=True) def _default_value_fn():
this_in_fac_b = tf.case(pf_pairs_in_fac_bs, case_default, exclusive=True) with tf.control_dependencies([tf.Assert(False, ["Reached default"])]):
this_out_fac_W = tf.case(pf_pairs_out_fac_Ws, case_default, exclusive=True) return tf.identity(pairs[0][1]())
this_out_fac_b = tf.case(pf_pairs_out_fac_bs, case_default, exclusive=True) return tf.case(pairs, _default_value_fn, exclusive=True)
this_in_fac_W = _case_with_no_default(pf_pairs_in_fac_Ws)
this_in_fac_b = _case_with_no_default(pf_pairs_in_fac_bs)
this_out_fac_W = _case_with_no_default(pf_pairs_out_fac_Ws)
this_out_fac_b = _case_with_no_default(pf_pairs_out_fac_bs)
# External inputs (not changing by dataset, by definition). # External inputs (not changing by dataset, by definition).
if hps.ext_input_dim > 0: if hps.ext_input_dim > 0:
...@@ -952,7 +993,7 @@ class LFADS(object): ...@@ -952,7 +993,7 @@ class LFADS(object):
session = tf.get_default_session() session = tf.get_default_session()
self.logfile = os.path.join(hps.lfads_save_dir, "lfads_log") self.logfile = os.path.join(hps.lfads_save_dir, "lfads_log")
self.writer = tf.summary.FileWriter(self.logfile, session.graph) self.writer = tf.summary.FileWriter(self.logfile)
def build_feed_dict(self, train_name, data_bxtxd, ext_input_bxtxi=None, def build_feed_dict(self, train_name, data_bxtxd, ext_input_bxtxi=None,
keep_prob=None): keep_prob=None):
...@@ -1678,7 +1719,7 @@ class LFADS(object): ...@@ -1678,7 +1719,7 @@ class LFADS(object):
out_dist_params = np.zeros([E_to_process, T, D+D]) out_dist_params = np.zeros([E_to_process, T, D+D])
else: else:
assert False, "NIY" assert False, "NIY"
costs = np.zeros(E_to_process) costs = np.zeros(E_to_process)
nll_bound_vaes = np.zeros(E_to_process) nll_bound_vaes = np.zeros(E_to_process)
nll_bound_iwaes = np.zeros(E_to_process) nll_bound_iwaes = np.zeros(E_to_process)
...@@ -1878,7 +1919,7 @@ class LFADS(object): ...@@ -1878,7 +1919,7 @@ class LFADS(object):
for i, (var, var_eval) in enumerate(zip(all_tf_vars, all_tf_vars_eval)): for i, (var, var_eval) in enumerate(zip(all_tf_vars, all_tf_vars_eval)):
if any(s in include_strs for s in var.name): if any(s in include_strs for s in var.name):
if not isinstance(var_eval, np.ndarray): # for H5PY if not isinstance(var_eval, np.ndarray): # for H5PY
print(var.name, """ is not numpy array, saving as numpy array print(var.name, """ is not numpy array, saving as numpy array
with value: """, var_eval, type(var_eval)) with value: """, var_eval, type(var_eval))
e = np.array(var_eval) e = np.array(var_eval)
print(e, type(e)) print(e, type(e))
......
...@@ -24,7 +24,7 @@ from utils import write_datasets ...@@ -24,7 +24,7 @@ from utils import write_datasets
from synthetic_data_utils import add_alignment_projections, generate_data from synthetic_data_utils import add_alignment_projections, generate_data
from synthetic_data_utils import generate_rnn, get_train_n_valid_inds from synthetic_data_utils import generate_rnn, get_train_n_valid_inds
from synthetic_data_utils import nparray_and_transpose from synthetic_data_utils import nparray_and_transpose
from synthetic_data_utils import spikify_data, split_list_by_inds from synthetic_data_utils import spikify_data, gaussify_data, split_list_by_inds
import matplotlib import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import scipy.signal import scipy.signal
...@@ -37,6 +37,7 @@ flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/", ...@@ -37,6 +37,7 @@ flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
"Directory for saving data.") "Directory for saving data.")
flags.DEFINE_string("datafile_name", "thits_data", flags.DEFINE_string("datafile_name", "thits_data",
"Name of data file for input case.") "Name of data file for input case.")
flags.DEFINE_string("noise_type", "poisson", "Noise type for data.")
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.") flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.") flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 100, "Number of conditions") flags.DEFINE_integer("C", 100, "Number of conditions")
...@@ -45,8 +46,8 @@ flags.DEFINE_integer("S", 50, "Number of sampled units from RNN") ...@@ -45,8 +46,8 @@ flags.DEFINE_integer("S", 50, "Number of sampled units from RNN")
flags.DEFINE_integer("npcs", 10, "Number of PCS for multi-session case.") flags.DEFINE_integer("npcs", 10, "Number of PCS for multi-session case.")
flags.DEFINE_float("train_percentage", 4.0/5.0, flags.DEFINE_float("train_percentage", 4.0/5.0,
"Percentage of train vs validation trials") "Percentage of train vs validation trials")
flags.DEFINE_integer("nspikifications", 40, flags.DEFINE_integer("nreplications", 40,
"Number of spikifications of the same underlying rates.") "Number of noise replications of the same underlying rates.")
flags.DEFINE_float("g", 1.5, "Complexity of dynamics") flags.DEFINE_float("g", 1.5, "Complexity of dynamics")
flags.DEFINE_float("x0_std", 1.0, flags.DEFINE_float("x0_std", 1.0,
"Volume from which to pull initial conditions (affects diversity of dynamics.") "Volume from which to pull initial conditions (affects diversity of dynamics.")
...@@ -73,8 +74,8 @@ C = FLAGS.C ...@@ -73,8 +74,8 @@ C = FLAGS.C
N = FLAGS.N N = FLAGS.N
S = FLAGS.S S = FLAGS.S
input_magnitude = FLAGS.input_magnitude input_magnitude = FLAGS.input_magnitude
nspikifications = FLAGS.nspikifications nreplications = FLAGS.nreplications
E = nspikifications * C # total number of trials E = nreplications * C # total number of trials
# S is the number of measurements in each datasets, w/ each # S is the number of measurements in each datasets, w/ each
# dataset having a different set of observations. # dataset having a different set of observations.
ndatasets = N/S # ok if rounded down ndatasets = N/S # ok if rounded down
...@@ -87,9 +88,9 @@ rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate) ...@@ -87,9 +88,9 @@ rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate)
# Check to make sure the RNN is the one we used in the paper. # Check to make sure the RNN is the one we used in the paper.
if N == 50: if N == 50:
assert abs(rnn['W'][0,0] - 0.06239899) < 1e-8, 'Error in random seed?' assert abs(rnn['W'][0,0] - 0.06239899) < 1e-8, 'Error in random seed?'
rem_check = nspikifications * train_percentage rem_check = nreplications * train_percentage
assert abs(rem_check - int(rem_check)) < 1e-8, \ assert abs(rem_check - int(rem_check)) < 1e-8, \
'Train percentage * nspikifications should be integral number.' 'Train percentage * nreplications should be integral number.'
# Initial condition generation, and condition label generation. This # Initial condition generation, and condition label generation. This
...@@ -100,9 +101,9 @@ x0s = [] ...@@ -100,9 +101,9 @@ x0s = []
condition_labels = [] condition_labels = []
for c in range(C): for c in range(C):
x0 = FLAGS.x0_std * rng.randn(N, 1) x0 = FLAGS.x0_std * rng.randn(N, 1)
x0s.append(np.tile(x0, nspikifications)) # replicate x0 nspikifications times x0s.append(np.tile(x0, nreplications)) # replicate x0 nreplications times
# replicate the condition label nspikifications times # replicate the condition label nreplications times
for ns in range(nspikifications): for ns in range(nreplications):
condition_labels.append(condition_number) condition_labels.append(condition_number)
condition_number += 1 condition_number += 1
x0s = np.concatenate(x0s, axis=1) x0s = np.concatenate(x0s, axis=1)
...@@ -113,8 +114,8 @@ for n in range(ndatasets): ...@@ -113,8 +114,8 @@ for n in range(ndatasets):
print(n+1, " of ", ndatasets) print(n+1, " of ", ndatasets)
# First generate all firing rates. in the next loop, generate all # First generate all firing rates. in the next loop, generate all
# spikifications this allows the random state for rate generation to be # replications this allows the random state for rate generation to be
# independent of n_spikifications. # independent of n_replications.
dataset_name = 'dataset_N' + str(N) + '_S' + str(S) dataset_name = 'dataset_N' + str(N) + '_S' + str(S)
if S < N: if S < N:
dataset_name += '_n' + str(n+1) dataset_name += '_n' + str(n+1)
...@@ -136,17 +137,23 @@ for n in range(ndatasets): ...@@ -136,17 +137,23 @@ for n in range(ndatasets):
generate_data(rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn, generate_data(rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn,
input_magnitude=input_magnitude, input_magnitude=input_magnitude,
input_times=input_times) input_times=input_times)
spikes = spikify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
# split into train and validation sets if FLAGS.noise_type == "poisson":
noisy_data = spikify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
elif FLAGS.noise_type == "gaussian":
noisy_data = gaussify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
else:
raise ValueError("Only noise types supported are poisson or gaussian")
# split into train and validation sets
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage, train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
nspikifications) nreplications)
# Split the data, inputs, labels and times into train vs. validation. # Split the data, inputs, labels and times into train vs. validation.
rates_train, rates_valid = \ rates_train, rates_valid = \
split_list_by_inds(rates, train_inds, valid_inds) split_list_by_inds(rates, train_inds, valid_inds)
spikes_train, spikes_valid = \ noisy_data_train, noisy_data_valid = \
split_list_by_inds(spikes, train_inds, valid_inds) split_list_by_inds(noisy_data, train_inds, valid_inds)
input_train, inputs_valid = \ input_train, inputs_valid = \
split_list_by_inds(inputs, train_inds, valid_inds) split_list_by_inds(inputs, train_inds, valid_inds)
condition_labels_train, condition_labels_valid = \ condition_labels_train, condition_labels_valid = \
...@@ -154,25 +161,25 @@ for n in range(ndatasets): ...@@ -154,25 +161,25 @@ for n in range(ndatasets):
input_times_train, input_times_valid = \ input_times_train, input_times_valid = \
split_list_by_inds(input_times, train_inds, valid_inds) split_list_by_inds(input_times, train_inds, valid_inds)
# Turn rates, spikes, and input into numpy arrays. # Turn rates, noisy_data, and input into numpy arrays.
rates_train = nparray_and_transpose(rates_train) rates_train = nparray_and_transpose(rates_train)
rates_valid = nparray_and_transpose(rates_valid) rates_valid = nparray_and_transpose(rates_valid)
spikes_train = nparray_and_transpose(spikes_train) noisy_data_train = nparray_and_transpose(noisy_data_train)
spikes_valid = nparray_and_transpose(spikes_valid) noisy_data_valid = nparray_and_transpose(noisy_data_valid)
input_train = nparray_and_transpose(input_train) input_train = nparray_and_transpose(input_train)
inputs_valid = nparray_and_transpose(inputs_valid) inputs_valid = nparray_and_transpose(inputs_valid)
# Note that we put these 'truth' rates and input into this # Note that we put these 'truth' rates and input into this
# structure, the only data that is used in LFADS are the spike # structure, the only data that is used in LFADS are the noisy
# trains. The rest is either for printing or posterity. # data e.g. spike trains. The rest is either for printing or posterity.
data = {'train_truth': rates_train, data = {'train_truth': rates_train,
'valid_truth': rates_valid, 'valid_truth': rates_valid,
'input_train_truth' : input_train, 'input_train_truth' : input_train,
'input_valid_truth' : inputs_valid, 'input_valid_truth' : inputs_valid,
'train_data' : spikes_train, 'train_data' : noisy_data_train,
'valid_data' : spikes_valid, 'valid_data' : noisy_data_valid,
'train_percentage' : train_percentage, 'train_percentage' : train_percentage,
'nspikifications' : nspikifications, 'nreplications' : nreplications,
'dt' : rnn['dt'], 'dt' : rnn['dt'],
'input_magnitude' : input_magnitude, 'input_magnitude' : input_magnitude,
'input_times_train' : input_times_train, 'input_times_train' : input_times_train,
......
...@@ -18,20 +18,23 @@ ...@@ -18,20 +18,23 @@
SYNTH_PATH=/tmp/rnn_synth_data_v1.0/ SYNTH_PATH=/tmp/rnn_synth_data_v1.0/
echo "Generating chaotic rnn data with no input pulses (g=1.5)" echo "Generating chaotic rnn data with no input pulses (g=1.5) with spiking noise"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson'
echo "Generating chaotic rnn data with input pulses (g=1.5)" echo "Generating chaotic rnn data with no input pulses (g=1.5) with Gaussian noise"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g1p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs_gaussian --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='gaussian'
echo "Generating chaotic rnn data with input pulses (g=2.5)" echo "Generating chaotic rnn data with input pulses (g=1.5)"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g2p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=2.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g1p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson'
echo "Generate the multi-session RNN data (no multi-session synth example in paper)" echo "Generating chaotic rnn data with input pulses (g=2.5)"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_multisession --synth_data_seed=5 --T=1.0 --C=150 --N=100 --S=20 --npcs=10 --train_percentage=0.8 --nspikifications=40 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g2p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=2.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson'
echo "Generating Integration-to-bound RNN data" echo "Generate the multi-session RNN data (no multi-session synth example in paper)"
python generate_itb_data.py --save_dir=$SYNTH_PATH --datafile_name=itb_rnn --u_std=0.25 --checkpoint_path=SAMPLE_CHECKPOINT --synth_data_seed=5 --T=1.0 --C=800 --N=50 --train_percentage=0.8 --nspikifications=5 --tau=0.025 --dt=0.01 --max_firing_rate=30.0 python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_multisession --synth_data_seed=5 --T=1.0 --C=150 --N=100 --S=20 --npcs=10 --train_percentage=0.8 --nspikifications=40 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson'
echo "Generating chaotic rnn data with external input labels (no external input labels example in paper)" echo "Generating Integration-to-bound RNN data"
python generate_labeled_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnns_labeled --synth_data_seed=5 --T=1.0 --C=400 --N=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --max_firing_rate=30.0 python generate_itb_data.py --save_dir=$SYNTH_PATH --datafile_name=itb_rnn --u_std=0.25 --checkpoint_path=SAMPLE_CHECKPOINT --synth_data_seed=5 --T=1.0 --C=800 --N=50 --train_percentage=0.8 --nspikifications=5 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
echo "Generating chaotic rnn data with external input labels (no external input labels example in paper)"
python generate_labeled_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnns_labeled --synth_data_seed=5 --T=1.0 --C=400 --N=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
...@@ -132,11 +132,10 @@ def spikify_data(data_e, rng, dt=1.0, max_firing_rate=100): ...@@ -132,11 +132,10 @@ def spikify_data(data_e, rng, dt=1.0, max_firing_rate=100):
dt: how often the data are sampled dt: how often the data are sampled
max_firing_rate: the firing rate that is associated with a value of 1.0 max_firing_rate: the firing rate that is associated with a value of 1.0
Returns: Returns:
spikified_data_e: a list of length b of the data represented as spikes, spikified_e: a list of length b of the data represented as spikes,
sampled from the underlying poisson process. sampled from the underlying poisson process.
""" """
spikifies_data_e = []
E = len(data_e) E = len(data_e)
spikes_e = [] spikes_e = []
for e in range(E): for e in range(E):
...@@ -152,6 +151,31 @@ def spikify_data(data_e, rng, dt=1.0, max_firing_rate=100): ...@@ -152,6 +151,31 @@ def spikify_data(data_e, rng, dt=1.0, max_firing_rate=100):
return spikes_e return spikes_e
def gaussify_data(data_e, rng, dt=1.0, max_firing_rate=100):
""" Apply gaussian noise to a continuous dataset whose values are between
0.0 and 1.0
Args:
data_e: nexamples length list of NxT trials
dt: how often the data are sampled
max_firing_rate: the firing rate that is associated with a value of 1.0
Returns:
gauss_e: a list of length b of the data with noise.
"""
E = len(data_e)
mfr = max_firing_rate
gauss_e = []
for e in range(E):
data = data_e[e]
N,T = data.shape
noisy_data = data * mfr + np.random.randn(N,T) * (5.0*mfr) * np.sqrt(dt)
gauss_e.append(noisy_data)
return gauss_e
def get_train_n_valid_inds(num_trials, train_fraction, nspikifications): def get_train_n_valid_inds(num_trials, train_fraction, nspikifications):
"""Split the numbers between 0 and num_trials-1 into two portions for """Split the numbers between 0 and num_trials-1 into two portions for
training and validation, based on the train fraction. training and validation, based on the train fraction.
...@@ -295,6 +319,8 @@ def add_alignment_projections(datasets, npcs, ntime=None, nsamples=None): ...@@ -295,6 +319,8 @@ def add_alignment_projections(datasets, npcs, ntime=None, nsamples=None):
W_chxp, _, _, _ = \ W_chxp, _, _, _ = \
np.linalg.lstsq(all_data_zm_chxtc.T, all_data_pca_pxtc.T) np.linalg.lstsq(all_data_zm_chxtc.T, all_data_pca_pxtc.T)
dataset['alignment_matrix_cxf'] = W_chxp dataset['alignment_matrix_cxf'] = W_chxp
alignment_bias_cx1 = all_data_mean_nx1[cidx_s:cidx_f]
dataset['alignment_bias_c'] = np.squeeze(alignment_bias_cx1, axis=1)
do_debug_plot = False do_debug_plot = False
if do_debug_plot: if do_debug_plot:
......
...@@ -82,9 +82,9 @@ def linear(x, out_size, do_bias=True, alpha=1.0, identity_if_possible=False, ...@@ -82,9 +82,9 @@ def linear(x, out_size, do_bias=True, alpha=1.0, identity_if_possible=False,
return tf.matmul(x, W) return tf.matmul(x, W)
def init_linear(in_size, out_size, do_bias=True, mat_init_value=None, alpha=1.0, def init_linear(in_size, out_size, do_bias=True, mat_init_value=None,
identity_if_possible=False, normalized=False, bias_init_value=None, alpha=1.0, identity_if_possible=False,
name=None, collections=None): normalized=False, name=None, collections=None):
"""Linear (affine) transformation, y = x W + b, for a variety of """Linear (affine) transformation, y = x W + b, for a variety of
configurations. configurations.
...@@ -110,6 +110,9 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None, alpha=1.0, ...@@ -110,6 +110,9 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None, alpha=1.0,
if mat_init_value is not None and mat_init_value.shape != (in_size, out_size): if mat_init_value is not None and mat_init_value.shape != (in_size, out_size):
raise ValueError( raise ValueError(
'Provided mat_init_value must have shape [%d, %d].'%(in_size, out_size)) 'Provided mat_init_value must have shape [%d, %d].'%(in_size, out_size))
if bias_init_value is not None and bias_init_value.shape != (1,out_size):
raise ValueError(
'Provided bias_init_value must have shape [1,%d].'%(out_size,))
if mat_init_value is None: if mat_init_value is None:
stddev = alpha/np.sqrt(float(in_size)) stddev = alpha/np.sqrt(float(in_size))
...@@ -143,16 +146,20 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None, alpha=1.0, ...@@ -143,16 +146,20 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None, alpha=1.0,
w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init, w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init,
collections=w_collections) collections=w_collections)
b = None
if do_bias: if do_bias:
b_collections = [tf.GraphKeys.GLOBAL_VARIABLES] b_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
if collections: if collections:
b_collections += collections b_collections += collections
bname = (name + "/b") if name else "/b" bname = (name + "/b") if name else "/b"
b = tf.get_variable(bname, [1, out_size], if bias_init_value is None:
initializer=tf.zeros_initializer(), b = tf.get_variable(bname, [1, out_size],
collections=b_collections) initializer=tf.zeros_initializer(),
else: collections=b_collections)
b = None else:
b = tf.Variable(bias_init_value, name=bname,
collections=b_collections)
return (w, b) return (w, b)
......
...@@ -54,6 +54,17 @@ Extras: ...@@ -54,6 +54,17 @@ Extras:
Exporting a trained model for inference</a><br> Exporting a trained model for inference</a><br>
* <a href='g3doc/defining_your_own_model.md'> * <a href='g3doc/defining_your_own_model.md'>
Defining your own model architecture</a><br> Defining your own model architecture</a><br>
* <a href='g3doc/using_your_own_dataset.md'>
Bringing in your own dataset</a><br>
## Getting Help
Please report bugs to the tensorflow/models/ Github
[issue tracker](https://github.com/tensorflow/models/issues), prefixing the
issue name with "object_detection". To get help with issues you may encounter
using the Tensorflow Object Detection API, create a new question on
[StackOverflow](https://stackoverflow.com/) with the tags "tensorflow" and
"object-detection".
## Release information ## Release information
......
...@@ -270,6 +270,7 @@ py_library( ...@@ -270,6 +270,7 @@ py_library(
deps = [ deps = [
"//tensorflow", "//tensorflow",
"//tensorflow_models/object_detection/utils:ops", "//tensorflow_models/object_detection/utils:ops",
"//tensorflow_models/object_detection/utils:shape_utils",
"//tensorflow_models/object_detection/utils:static_shape", "//tensorflow_models/object_detection/utils:static_shape",
], ],
) )
......
...@@ -29,6 +29,7 @@ few box predictor architectures are shared across many models. ...@@ -29,6 +29,7 @@ few box predictor architectures are shared across many models.
from abc import abstractmethod from abc import abstractmethod
import tensorflow as tf import tensorflow as tf
from object_detection.utils import ops from object_detection.utils import ops
from object_detection.utils import shape_utils
from object_detection.utils import static_shape from object_detection.utils import static_shape
slim = tf.contrib.slim slim = tf.contrib.slim
...@@ -316,6 +317,8 @@ class MaskRCNNBoxPredictor(BoxPredictor): ...@@ -316,6 +317,8 @@ class MaskRCNNBoxPredictor(BoxPredictor):
self._predict_instance_masks = predict_instance_masks self._predict_instance_masks = predict_instance_masks
self._mask_prediction_conv_depth = mask_prediction_conv_depth self._mask_prediction_conv_depth = mask_prediction_conv_depth
self._predict_keypoints = predict_keypoints self._predict_keypoints = predict_keypoints
if self._predict_instance_masks:
raise ValueError('Mask prediction is unimplemented.')
if self._predict_keypoints: if self._predict_keypoints:
raise ValueError('Keypoint prediction is unimplemented.') raise ValueError('Keypoint prediction is unimplemented.')
if ((self._predict_instance_masks or self._predict_keypoints) and if ((self._predict_instance_masks or self._predict_keypoints) and
...@@ -524,23 +527,21 @@ class ConvolutionalBoxPredictor(BoxPredictor): ...@@ -524,23 +527,21 @@ class ConvolutionalBoxPredictor(BoxPredictor):
class_predictions_with_background = tf.sigmoid( class_predictions_with_background = tf.sigmoid(
class_predictions_with_background) class_predictions_with_background)
batch_size = static_shape.get_batch_size(image_features.get_shape()) combined_feature_map_shape = shape_utils.combined_static_and_dynamic_shape(
if batch_size is None: image_features)
features_height = static_shape.get_height(image_features.get_shape()) box_encodings = tf.reshape(
features_width = static_shape.get_width(image_features.get_shape()) box_encodings, tf.stack([combined_feature_map_shape[0],
flattened_predictions_size = (features_height * features_width * combined_feature_map_shape[1] *
num_predictions_per_location) combined_feature_map_shape[2] *
box_encodings = tf.reshape( num_predictions_per_location,
box_encodings, 1, self._box_code_size]))
[-1, flattened_predictions_size, 1, self._box_code_size]) class_predictions_with_background = tf.reshape(
class_predictions_with_background = tf.reshape( class_predictions_with_background,
class_predictions_with_background, tf.stack([combined_feature_map_shape[0],
[-1, flattened_predictions_size, num_class_slots]) combined_feature_map_shape[1] *
else: combined_feature_map_shape[2] *
box_encodings = tf.reshape( num_predictions_per_location,
box_encodings, [batch_size, -1, 1, self._box_code_size]) num_class_slots]))
class_predictions_with_background = tf.reshape(
class_predictions_with_background, [batch_size, -1, num_class_slots])
return {BOX_ENCODINGS: box_encodings, return {BOX_ENCODINGS: box_encodings,
CLASS_PREDICTIONS_WITH_BACKGROUND: CLASS_PREDICTIONS_WITH_BACKGROUND:
class_predictions_with_background} class_predictions_with_background}
...@@ -228,25 +228,24 @@ class DetectionModel(object): ...@@ -228,25 +228,24 @@ class DetectionModel(object):
fields.BoxListFields.keypoints] = groundtruth_keypoints_list fields.BoxListFields.keypoints] = groundtruth_keypoints_list
@abstractmethod @abstractmethod
def restore_fn(self, checkpoint_path, from_detection_checkpoint=True): def restore_map(self, from_detection_checkpoint=True):
"""Return callable for loading a foreign checkpoint into tensorflow graph. """Returns a map of variables to load from a foreign checkpoint.
Loads variables from a different tensorflow graph (typically feature Returns a map of variable names to load from a checkpoint to variables in
extractor variables). This enables the model to initialize based on weights the model graph. This enables the model to initialize based on weights from
from another task. For example, the feature extractor variables from a another task. For example, the feature extractor variables from a
classification model can be used to bootstrap training of an object classification model can be used to bootstrap training of an object
detector. When loading from an object detection model, the checkpoint model detector. When loading from an object detection model, the checkpoint model
should have the same parameters as this detection model with exception of should have the same parameters as this detection model with exception of
the num_classes parameter. the num_classes parameter.
Args: Args:
checkpoint_path: path to checkpoint to restore.
from_detection_checkpoint: whether to restore from a full detection from_detection_checkpoint: whether to restore from a full detection
checkpoint (with compatible variable names) or to restore from a checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training. classification checkpoint for initialization prior to training.
Returns: Returns:
a callable which takes a tf.Session as input and loads a checkpoint when A dict mapping variable names (to load from a checkpoint) to variables in
run. the model graph.
""" """
pass pass
...@@ -174,7 +174,8 @@ def batch_multiclass_non_max_suppression(boxes, ...@@ -174,7 +174,8 @@ def batch_multiclass_non_max_suppression(boxes,
change_coordinate_frame=False, change_coordinate_frame=False,
num_valid_boxes=None, num_valid_boxes=None,
masks=None, masks=None,
scope=None): scope=None,
parallel_iterations=32):
"""Multi-class version of non maximum suppression that operates on a batch. """Multi-class version of non maximum suppression that operates on a batch.
This op is similar to `multiclass_non_max_suppression` but operates on a batch This op is similar to `multiclass_non_max_suppression` but operates on a batch
...@@ -208,26 +209,28 @@ def batch_multiclass_non_max_suppression(boxes, ...@@ -208,26 +209,28 @@ def batch_multiclass_non_max_suppression(boxes,
float32 tensor containing box masks. `q` can be either number of classes float32 tensor containing box masks. `q` can be either number of classes
or 1 depending on whether a separate mask is predicted per class. or 1 depending on whether a separate mask is predicted per class.
scope: tf scope name. scope: tf scope name.
parallel_iterations: (optional) number of batch items to process in
parallel.
Returns: Returns:
A dictionary containing the following entries: 'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor
'detection_boxes': A [batch_size, max_detections, 4] float32 tensor
containing the non-max suppressed boxes. containing the non-max suppressed boxes.
'detection_scores': A [bath_size, max_detections] float32 tensor containing 'nmsed_scores': A [batch_size, max_detections] float32 tensor containing
the scores for the boxes. the scores for the boxes.
'detection_classes': A [batch_size, max_detections] float32 tensor 'nmsed_classes': A [batch_size, max_detections] float32 tensor
containing the class for boxes. containing the class for boxes.
'num_detections': A [batchsize] float32 tensor indicating the number of 'nmsed_masks': (optional) a
[batch_size, max_detections, mask_height, mask_width] float32 tensor
containing masks for each selected box. This is set to None if input
`masks` is None.
'num_detections': A [batch_size] int32 tensor indicating the number of
valid detections per batch item. Only the top num_detections[i] entries in valid detections per batch item. Only the top num_detections[i] entries in
nms_boxes[i], nms_scores[i] and nms_class[i] are valid. the rest of the nms_boxes[i], nms_scores[i] and nms_class[i] are valid. the rest of the
entries are zero paddings. entries are zero paddings.
'detection_masks': (optional) a
[batch_size, max_detections, mask_height, mask_width] float32 tensor
containing masks for each selected box.
Raises: Raises:
ValueError: if iou_thresh is not in [0, 1] or if input boxlist does not have ValueError: if `q` in boxes.shape is not 1 or not equal to number of
a valid scores field. classes as inferred from scores.shape.
""" """
q = boxes.shape[2].value q = boxes.shape[2].value
num_classes = scores.shape[2].value num_classes = scores.shape[2].value
...@@ -235,36 +238,45 @@ def batch_multiclass_non_max_suppression(boxes, ...@@ -235,36 +238,45 @@ def batch_multiclass_non_max_suppression(boxes,
raise ValueError('third dimension of boxes must be either 1 or equal ' raise ValueError('third dimension of boxes must be either 1 or equal '
'to the third dimension of scores') 'to the third dimension of scores')
original_masks = masks
with tf.name_scope(scope, 'BatchMultiClassNonMaxSuppression'): with tf.name_scope(scope, 'BatchMultiClassNonMaxSuppression'):
per_image_boxes_list = tf.unstack(boxes) boxes_shape = boxes.shape
per_image_scores_list = tf.unstack(scores) batch_size = boxes_shape[0].value
num_valid_boxes_list = len(per_image_boxes_list) * [None] num_anchors = boxes_shape[1].value
per_image_masks_list = len(per_image_boxes_list) * [None]
if num_valid_boxes is not None: if batch_size is None:
num_valid_boxes_list = tf.unstack(num_valid_boxes) batch_size = tf.shape(boxes)[0]
if masks is not None: if num_anchors is None:
per_image_masks_list = tf.unstack(masks) num_anchors = tf.shape(boxes)[1]
# If num valid boxes aren't provided, create one and mark all boxes as
# valid.
if num_valid_boxes is None:
num_valid_boxes = tf.ones([batch_size], dtype=tf.int32) * num_anchors
detection_boxes_list = [] # If masks aren't provided, create dummy masks so we can only have one copy
detection_scores_list = [] # of single_image_nms_fn and discard the dummy masks after map_fn.
detection_classes_list = [] if masks is None:
num_detections_list = [] masks_shape = tf.stack([batch_size, num_anchors, 1, 0, 0])
detection_masks_list = [] masks = tf.zeros(masks_shape)
for (per_image_boxes, per_image_scores, per_image_masks, num_valid_boxes
) in zip(per_image_boxes_list, per_image_scores_list, def single_image_nms_fn(args):
per_image_masks_list, num_valid_boxes_list): """Runs NMS on a single image and returns padded output."""
if num_valid_boxes is not None: (per_image_boxes, per_image_scores, per_image_masks,
per_image_boxes = tf.reshape( per_image_num_valid_boxes) = args
tf.slice(per_image_boxes, 3*[0], per_image_boxes = tf.reshape(
tf.stack([num_valid_boxes, -1, -1])), [-1, q, 4]) tf.slice(per_image_boxes, 3 * [0],
per_image_scores = tf.reshape( tf.stack([per_image_num_valid_boxes, -1, -1])), [-1, q, 4])
tf.slice(per_image_scores, [0, 0], per_image_scores = tf.reshape(
tf.stack([num_valid_boxes, -1])), [-1, num_classes]) tf.slice(per_image_scores, [0, 0],
if masks is not None: tf.stack([per_image_num_valid_boxes, -1])),
per_image_masks = tf.reshape( [-1, num_classes])
tf.slice(per_image_masks, 4*[0],
tf.stack([num_valid_boxes, -1, -1, -1])), per_image_masks = tf.reshape(
[-1, q, masks.shape[3].value, masks.shape[4].value]) tf.slice(per_image_masks, 4 * [0],
tf.stack([per_image_num_valid_boxes, -1, -1, -1])),
[-1, q, per_image_masks.shape[2].value,
per_image_masks.shape[3].value])
nmsed_boxlist = multiclass_non_max_suppression( nmsed_boxlist = multiclass_non_max_suppression(
per_image_boxes, per_image_boxes,
per_image_scores, per_image_scores,
...@@ -275,24 +287,26 @@ def batch_multiclass_non_max_suppression(boxes, ...@@ -275,24 +287,26 @@ def batch_multiclass_non_max_suppression(boxes,
masks=per_image_masks, masks=per_image_masks,
clip_window=clip_window, clip_window=clip_window,
change_coordinate_frame=change_coordinate_frame) change_coordinate_frame=change_coordinate_frame)
num_detections_list.append(tf.to_float(nmsed_boxlist.num_boxes()))
padded_boxlist = box_list_ops.pad_or_clip_box_list(nmsed_boxlist, padded_boxlist = box_list_ops.pad_or_clip_box_list(nmsed_boxlist,
max_total_size) max_total_size)
detection_boxes_list.append(padded_boxlist.get()) num_detections = nmsed_boxlist.num_boxes()
detection_scores_list.append( nmsed_boxes = padded_boxlist.get()
padded_boxlist.get_field(fields.BoxListFields.scores)) nmsed_scores = padded_boxlist.get_field(fields.BoxListFields.scores)
detection_classes_list.append( nmsed_classes = padded_boxlist.get_field(fields.BoxListFields.classes)
padded_boxlist.get_field(fields.BoxListFields.classes)) nmsed_masks = padded_boxlist.get_field(fields.BoxListFields.masks)
if masks is not None: return [nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
detection_masks_list.append( num_detections]
padded_boxlist.get_field(fields.BoxListFields.masks))
nms_dict = { (batch_nmsed_boxes, batch_nmsed_scores,
'detection_boxes': tf.stack(detection_boxes_list), batch_nmsed_classes, batch_nmsed_masks,
'detection_scores': tf.stack(detection_scores_list), batch_num_detections) = tf.map_fn(
'detection_classes': tf.stack(detection_classes_list), single_image_nms_fn,
'num_detections': tf.stack(num_detections_list) elems=[boxes, scores, masks, num_valid_boxes],
} dtype=[tf.float32, tf.float32, tf.float32, tf.float32, tf.int32],
if masks is not None: parallel_iterations=parallel_iterations)
nms_dict['detection_masks'] = tf.stack(detection_masks_list)
return nms_dict if original_masks is None:
batch_nmsed_masks = None
return (batch_nmsed_boxes, batch_nmsed_scores, batch_nmsed_classes,
batch_nmsed_masks, batch_num_detections)
...@@ -496,15 +496,21 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase): ...@@ -496,15 +496,21 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
exp_nms_scores = [[.95, .9, .85, .3]] exp_nms_scores = [[.95, .9, .85, .3]]
exp_nms_classes = [[0, 0, 1, 0]] exp_nms_classes = [[0, 0, 1, 0]]
nms_dict = post_processing.batch_multiclass_non_max_suppression( (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
boxes, scores, score_thresh, iou_thresh, num_detections) = post_processing.batch_multiclass_non_max_suppression(
max_size_per_class=max_output_size, max_total_size=max_output_size) boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size)
self.assertIsNone(nmsed_masks)
with self.test_session() as sess: with self.test_session() as sess:
nms_output = sess.run(nms_dict) (nmsed_boxes, nmsed_scores, nmsed_classes,
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners) num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
self.assertAllClose(nms_output['detection_scores'], exp_nms_scores) num_detections])
self.assertAllClose(nms_output['detection_classes'], exp_nms_classes) self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertEqual(nms_output['num_detections'], [4]) self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertEqual(num_detections, [4])
def test_batch_multiclass_nms_with_batch_size_2(self): def test_batch_multiclass_nms_with_batch_size_2(self):
boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]], boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
...@@ -524,28 +530,42 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase): ...@@ -524,28 +530,42 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
iou_thresh = .5 iou_thresh = .5
max_output_size = 4 max_output_size = 4
exp_nms_corners = [[[0, 10, 1, 11], exp_nms_corners = np.array([[[0, 10, 1, 11],
[0, 0, 1, 1], [0, 0, 1, 1],
[0, 0, 0, 0], [0, 0, 0, 0],
[0, 0, 0, 0]], [0, 0, 0, 0]],
[[0, 999, 2, 1004], [[0, 999, 2, 1004],
[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1],
[0, 100, 1, 101], [0, 100, 1, 101],
[0, 0, 0, 0]]] [0, 0, 0, 0]]])
exp_nms_scores = [[.95, .9, 0, 0], exp_nms_scores = np.array([[.95, .9, 0, 0],
[.85, .5, .3, 0]] [.85, .5, .3, 0]])
exp_nms_classes = [[0, 0, 0, 0], exp_nms_classes = np.array([[0, 0, 0, 0],
[1, 0, 0, 0]] [1, 0, 0, 0]])
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size)
self.assertIsNone(nmsed_masks)
# Check static shapes
self.assertAllEqual(nmsed_boxes.shape.as_list(),
exp_nms_corners.shape)
self.assertAllEqual(nmsed_scores.shape.as_list(),
exp_nms_scores.shape)
self.assertAllEqual(nmsed_classes.shape.as_list(),
exp_nms_classes.shape)
self.assertEqual(num_detections.shape.as_list(), [2])
nms_dict = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size)
with self.test_session() as sess: with self.test_session() as sess:
nms_output = sess.run(nms_dict) (nmsed_boxes, nmsed_scores, nmsed_classes,
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners) num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
self.assertAllClose(nms_output['detection_scores'], exp_nms_scores) num_detections])
self.assertAllClose(nms_output['detection_classes'], exp_nms_classes) self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nms_output['num_detections'], [2, 3]) self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [2, 3])
def test_batch_multiclass_nms_with_masks(self): def test_batch_multiclass_nms_with_masks(self):
boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]], boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
...@@ -574,38 +594,126 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase): ...@@ -574,38 +594,126 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
iou_thresh = .5 iou_thresh = .5
max_output_size = 4 max_output_size = 4
exp_nms_corners = [[[0, 10, 1, 11], exp_nms_corners = np.array([[[0, 10, 1, 11],
[0, 0, 1, 1], [0, 0, 1, 1],
[0, 0, 0, 0], [0, 0, 0, 0],
[0, 0, 0, 0]], [0, 0, 0, 0]],
[[0, 999, 2, 1004], [[0, 999, 2, 1004],
[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1],
[0, 100, 1, 101], [0, 100, 1, 101],
[0, 0, 0, 0]]] [0, 0, 0, 0]]])
exp_nms_scores = [[.95, .9, 0, 0], exp_nms_scores = np.array([[.95, .9, 0, 0],
[.85, .5, .3, 0]] [.85, .5, .3, 0]])
exp_nms_classes = [[0, 0, 0, 0], exp_nms_classes = np.array([[0, 0, 0, 0],
[1, 0, 0, 0]] [1, 0, 0, 0]])
exp_nms_masks = [[[[6, 7], [8, 9]], exp_nms_masks = np.array([[[[6, 7], [8, 9]],
[[0, 1], [2, 3]], [[0, 1], [2, 3]],
[[0, 0], [0, 0]], [[0, 0], [0, 0]],
[[0, 0], [0, 0]]], [[0, 0], [0, 0]]],
[[[13, 14], [15, 16]], [[[13, 14], [15, 16]],
[[8, 9], [10, 11]], [[8, 9], [10, 11]],
[[10, 11], [12, 13]], [[10, 11], [12, 13]],
[[0, 0], [0, 0]]]] [[0, 0], [0, 0]]]])
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
masks=masks)
# Check static shapes
self.assertAllEqual(nmsed_boxes.shape.as_list(), exp_nms_corners.shape)
self.assertAllEqual(nmsed_scores.shape.as_list(), exp_nms_scores.shape)
self.assertAllEqual(nmsed_classes.shape.as_list(), exp_nms_classes.shape)
self.assertAllEqual(nmsed_masks.shape.as_list(), exp_nms_masks.shape)
self.assertEqual(num_detections.shape.as_list(), [2])
with self.test_session() as sess:
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
nmsed_masks, num_detections])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [2, 3])
self.assertAllClose(nmsed_masks, exp_nms_masks)
def test_batch_multiclass_nms_with_dynamic_batch_size(self):
boxes_placeholder = tf.placeholder(tf.float32, shape=(None, None, 2, 4))
scores_placeholder = tf.placeholder(tf.float32, shape=(None, None, 2))
masks_placeholder = tf.placeholder(tf.float32, shape=(None, None, 2, 2, 2))
boxes = np.array([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]]],
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]])
scores = np.array([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]])
masks = np.array([[[[[0, 1], [2, 3]], [[1, 2], [3, 4]]],
[[[2, 3], [4, 5]], [[3, 4], [5, 6]]],
[[[4, 5], [6, 7]], [[5, 6], [7, 8]]],
[[[6, 7], [8, 9]], [[7, 8], [9, 10]]]],
[[[[8, 9], [10, 11]], [[9, 10], [11, 12]]],
[[[10, 11], [12, 13]], [[11, 12], [13, 14]]],
[[[12, 13], [14, 15]], [[13, 14], [15, 16]]],
[[[14, 15], [16, 17]], [[15, 16], [17, 18]]]]])
score_thresh = 0.1
iou_thresh = .5
max_output_size = 4
exp_nms_corners = np.array([[[0, 10, 1, 11],
[0, 0, 1, 1],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 999, 2, 1004],
[0, 10.1, 1, 11.1],
[0, 100, 1, 101],
[0, 0, 0, 0]]])
exp_nms_scores = np.array([[.95, .9, 0, 0],
[.85, .5, .3, 0]])
exp_nms_classes = np.array([[0, 0, 0, 0],
[1, 0, 0, 0]])
exp_nms_masks = np.array([[[[6, 7], [8, 9]],
[[0, 1], [2, 3]],
[[0, 0], [0, 0]],
[[0, 0], [0, 0]]],
[[[13, 14], [15, 16]],
[[8, 9], [10, 11]],
[[10, 11], [12, 13]],
[[0, 0], [0, 0]]]])
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections) = post_processing.batch_multiclass_non_max_suppression(
boxes_placeholder, scores_placeholder, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
masks=masks_placeholder)
# Check static shapes
self.assertAllEqual(nmsed_boxes.shape.as_list(), [None, 4, 4])
self.assertAllEqual(nmsed_scores.shape.as_list(), [None, 4])
self.assertAllEqual(nmsed_classes.shape.as_list(), [None, 4])
self.assertAllEqual(nmsed_masks.shape.as_list(), [None, 4, 2, 2])
self.assertEqual(num_detections.shape.as_list(), [None])
nms_dict = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
masks=masks)
with self.test_session() as sess: with self.test_session() as sess:
nms_output = sess.run(nms_dict) (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners) num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
self.assertAllClose(nms_output['detection_scores'], exp_nms_scores) nmsed_masks, num_detections],
self.assertAllClose(nms_output['detection_classes'], exp_nms_classes) feed_dict={boxes_placeholder: boxes,
self.assertAllClose(nms_output['num_detections'], [2, 3]) scores_placeholder: scores,
self.assertAllClose(nms_output['detection_masks'], exp_nms_masks) masks_placeholder: masks})
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [2, 3])
self.assertAllClose(nmsed_masks, exp_nms_masks)
def test_batch_multiclass_nms_with_masks_and_num_valid_boxes(self): def test_batch_multiclass_nms_with_masks_and_num_valid_boxes(self):
boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]], boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
...@@ -656,17 +764,21 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase): ...@@ -656,17 +764,21 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
[[0, 0], [0, 0]], [[0, 0], [0, 0]],
[[0, 0], [0, 0]]]] [[0, 0], [0, 0]]]]
nms_dict = post_processing.batch_multiclass_non_max_suppression( (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
boxes, scores, score_thresh, iou_thresh, num_detections) = post_processing.batch_multiclass_non_max_suppression(
max_size_per_class=max_output_size, max_total_size=max_output_size, boxes, scores, score_thresh, iou_thresh,
num_valid_boxes=num_valid_boxes, masks=masks) max_size_per_class=max_output_size, max_total_size=max_output_size,
num_valid_boxes=num_valid_boxes, masks=masks)
with self.test_session() as sess: with self.test_session() as sess:
nms_output = sess.run(nms_dict) (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners) num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
self.assertAllClose(nms_output['detection_scores'], exp_nms_scores) nmsed_masks, num_detections])
self.assertAllClose(nms_output['detection_classes'], exp_nms_classes) self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nms_output['num_detections'], [1, 1]) self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nms_output['detection_masks'], exp_nms_masks) self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [1, 1])
self.assertAllClose(nmsed_masks, exp_nms_masks)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment