Commit 356c98bd authored by Kaushik Shivakumar

Merge remote-tracking branch 'upstream/master' into detr-push-3

parents d31aba8a b9785623
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Provides DeepLab model definition and helper functions.
DeepLab is a deep learning system for semantic image segmentation with
the following features:
(1) Atrous convolution to explicitly control the resolution at which
feature responses are computed within Deep Convolutional Neural Networks.
(2) Atrous spatial pyramid pooling (ASPP) to robustly segment objects at
multiple scales with filters at multiple sampling rates and effective
fields-of-views.
(3) ASPP module augmented with image-level feature and batch normalization.
(4) A simple yet effective decoder module to recover the object boundaries.
See the following papers for more details:
"Encoder-Decoder with Atrous Separable Convolution for Semantic Image
Segmentation"
Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, Hartwig Adam.
(https://arxiv.org/abs/1802.02611)
"Rethinking Atrous Convolution for Semantic Image Segmentation,"
Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam
(https://arxiv.org/abs/1706.05587)
"DeepLab: Semantic Image Segmentation with Deep Convolutional Nets,
Atrous Convolution, and Fully Connected CRFs",
Liang-Chieh Chen*, George Papandreou*, Iasonas Kokkinos, Kevin Murphy,
Alan L Yuille (* equal contribution)
(https://arxiv.org/abs/1606.00915)
"Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected
CRFs"
Liang-Chieh Chen*, George Papandreou*, Iasonas Kokkinos, Kevin Murphy,
Alan L. Yuille (* equal contribution)
(https://arxiv.org/abs/1412.7062)
"""
import collections
import tensorflow as tf
from deeplab import model
from feelvos import common
from feelvos.utils import embedding_utils
from feelvos.utils import train_utils
slim = tf.contrib.slim
get_branch_logits = model.get_branch_logits
get_extra_layer_scopes = model.get_extra_layer_scopes
multi_scale_logits_v2 = model.multi_scale_logits
refine_by_decoder = model.refine_by_decoder
scale_dimension = model.scale_dimension
split_separable_conv2d = model.split_separable_conv2d
MERGED_LOGITS_SCOPE = model.MERGED_LOGITS_SCOPE
IMAGE_POOLING_SCOPE = model.IMAGE_POOLING_SCOPE
ASPP_SCOPE = model.ASPP_SCOPE
CONCAT_PROJECTION_SCOPE = model.CONCAT_PROJECTION_SCOPE
def predict_labels(images,
model_options,
image_pyramid=None,
reference_labels=None,
k_nearest_neighbors=1,
embedding_dimension=None,
use_softmax_feedback=False,
initial_softmax_feedback=None,
embedding_seg_feature_dimension=256,
embedding_seg_n_layers=4,
embedding_seg_kernel_size=7,
embedding_seg_atrous_rates=None,
also_return_softmax_probabilities=False,
num_frames_per_video=None,
normalize_nearest_neighbor_distances=False,
also_attend_to_previous_frame=False,
use_local_previous_frame_attention=False,
previous_frame_attention_window_size=9,
use_first_frame_matching=True,
also_return_embeddings=False,
ref_embeddings=None):
"""Predicts segmentation labels.
Args:
images: A tensor of size [batch, height, width, channels].
model_options: An InternalModelOptions instance to configure models.
image_pyramid: Input image scales for multi-scale feature extraction.
reference_labels: A tensor of size [batch, height, width, 1], the ground
truth labels used to perform a nearest neighbor query.
k_nearest_neighbors: Integer, the number of neighbors to use for nearest
neighbor queries.
embedding_dimension: Integer, the dimension used for the learned embedding.
use_softmax_feedback: Boolean, whether to give the softmax predictions of
the last frame as additional input to the segmentation head.
initial_softmax_feedback: Float32 tensor, or None. Can be used to
initialize the softmax predictions used for the feedback loop.
Typically only useful for inference. Only has an effect if
use_softmax_feedback is True.
embedding_seg_feature_dimension: Integer, the dimensionality used in the
segmentation head layers.
embedding_seg_n_layers: Integer, the number of layers in the segmentation
head.
embedding_seg_kernel_size: Integer, the kernel size used in the
segmentation head.
embedding_seg_atrous_rates: List of integers of length
embedding_seg_n_layers, the atrous rates to use for the segmentation head.
also_return_softmax_probabilities: Boolean, if true, additionally return
the softmax probabilities as second return value.
num_frames_per_video: Integer, the number of frames per video.
normalize_nearest_neighbor_distances: Boolean, whether to normalize the
nearest neighbor distances to [0,1] using sigmoid, scale and shift.
also_attend_to_previous_frame: Boolean, whether to also use nearest
neighbor attention with respect to the previous frame.
use_local_previous_frame_attention: Boolean, whether to restrict the
previous frame attention to a local search window.
Only has an effect if also_attend_to_previous_frame is True.
previous_frame_attention_window_size: Integer, the window size used for
local previous frame attention, if use_local_previous_frame_attention
is True.
use_first_frame_matching: Boolean, whether to extract features by matching
to the reference frame. This should always be true except for ablation
experiments.
also_return_embeddings: Boolean, whether to return the embeddings as well.
ref_embeddings: Tuple of
(first_frame_embeddings, previous_frame_embeddings),
each of shape [batch, height, width, embedding_dimension], or None.
Returns:
A dictionary with keys specifying the output_type (e.g., semantic
prediction) and values storing Tensors representing predictions (argmax
over channels). Each prediction has size [batch, height, width].
If also_return_softmax_probabilities is True, the second return value are
the softmax probabilities.
If also_return_embeddings is True, it will also return an embeddings
tensor of shape [batch, height, width, embedding_dimension].
Raises:
ValueError: If classification_loss is not softmax, softmax_with_attention,
or triplet.
"""
if (model_options.classification_loss == 'triplet' and
reference_labels is None):
raise ValueError('Need reference_labels for triplet loss')
if model_options.classification_loss == 'softmax_with_attention':
if embedding_dimension is None:
raise ValueError('Need embedding_dimension for softmax_with_attention '
'loss')
if reference_labels is None:
raise ValueError('Need reference_labels for softmax_with_attention loss')
res = (
multi_scale_logits_with_nearest_neighbor_matching(
images,
model_options=model_options,
image_pyramid=image_pyramid,
is_training=False,
reference_labels=reference_labels,
clone_batch_size=1,
num_frames_per_video=num_frames_per_video,
embedding_dimension=embedding_dimension,
max_neighbors_per_object=0,
k_nearest_neighbors=k_nearest_neighbors,
use_softmax_feedback=use_softmax_feedback,
initial_softmax_feedback=initial_softmax_feedback,
embedding_seg_feature_dimension=embedding_seg_feature_dimension,
embedding_seg_n_layers=embedding_seg_n_layers,
embedding_seg_kernel_size=embedding_seg_kernel_size,
embedding_seg_atrous_rates=embedding_seg_atrous_rates,
normalize_nearest_neighbor_distances=
normalize_nearest_neighbor_distances,
also_attend_to_previous_frame=also_attend_to_previous_frame,
use_local_previous_frame_attention=
use_local_previous_frame_attention,
previous_frame_attention_window_size=
previous_frame_attention_window_size,
use_first_frame_matching=use_first_frame_matching,
also_return_embeddings=also_return_embeddings,
ref_embeddings=ref_embeddings
))
if also_return_embeddings:
outputs_to_scales_to_logits, embeddings = res
else:
outputs_to_scales_to_logits = res
embeddings = None
else:
outputs_to_scales_to_logits = multi_scale_logits_v2(
images,
model_options=model_options,
image_pyramid=image_pyramid,
is_training=False,
fine_tune_batch_norm=False)
predictions = {}
for output in sorted(outputs_to_scales_to_logits):
scales_to_logits = outputs_to_scales_to_logits[output]
original_logits = scales_to_logits[MERGED_LOGITS_SCOPE]
if isinstance(original_logits, list):
assert len(original_logits) == 1
original_logits = original_logits[0]
logits = tf.image.resize_bilinear(original_logits, tf.shape(images)[1:3],
align_corners=True)
if model_options.classification_loss in ('softmax',
'softmax_with_attention'):
predictions[output] = tf.argmax(logits, 3)
elif model_options.classification_loss == 'triplet':
# To keep this fast, we do the nearest neighbor assignment at the
# resolution at which the embedding is extracted and scale the result up
# afterwards.
embeddings = original_logits
reference_labels_logits_size = tf.squeeze(
tf.image.resize_nearest_neighbor(
reference_labels[tf.newaxis],
train_utils.resolve_shape(embeddings)[1:3],
align_corners=True), axis=0)
nn_labels = embedding_utils.assign_labels_by_nearest_neighbors(
embeddings[0], embeddings[1:], reference_labels_logits_size,
k_nearest_neighbors)
predictions[common.OUTPUT_TYPE] = tf.image.resize_nearest_neighbor(
nn_labels, tf.shape(images)[1:3], align_corners=True)
else:
raise ValueError(
'Only support softmax, triplet, or softmax_with_attention for '
'classification_loss.')
if also_return_embeddings:
assert also_return_softmax_probabilities
return predictions, tf.nn.softmax(original_logits, axis=-1), embeddings
elif also_return_softmax_probabilities:
return predictions, tf.nn.softmax(original_logits, axis=-1)
else:
return predictions
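# Hedged usage sketch (shapes and values below are assumptions, not taken from
# the original code): for inference with the softmax_with_attention loss,
# predict_labels is given the query frames together with the ground truth mask
# of the first (reference) frame.
#
#   # images: [num_frames, height, width, 3], frame 0 is the reference frame.
#   # reference_labels: [1, height, width, 1], mask of the reference frame.
#   predictions = predict_labels(
#       images, model_options,
#       reference_labels=reference_labels,
#       embedding_dimension=100,  # hypothetical value
#       use_softmax_feedback=True)
#   semantic = predictions[common.OUTPUT_TYPE]  # [num_frames, height, width]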
def multi_scale_logits_with_nearest_neighbor_matching(
images,
model_options,
image_pyramid,
clone_batch_size,
reference_labels,
num_frames_per_video,
embedding_dimension,
max_neighbors_per_object,
weight_decay=0.0001,
is_training=False,
fine_tune_batch_norm=False,
k_nearest_neighbors=1,
use_softmax_feedback=False,
initial_softmax_feedback=None,
embedding_seg_feature_dimension=256,
embedding_seg_n_layers=4,
embedding_seg_kernel_size=7,
embedding_seg_atrous_rates=None,
normalize_nearest_neighbor_distances=False,
also_attend_to_previous_frame=False,
damage_initial_previous_frame_mask=False,
use_local_previous_frame_attention=False,
previous_frame_attention_window_size=9,
use_first_frame_matching=True,
also_return_embeddings=False,
ref_embeddings=None):
"""Gets the logits for multi-scale inputs using nearest neighbor attention.
Adjusted version of multi_scale_logits_v2 to support nearest neighbor
attention and a variable number of classes for each element of the batch.
The returned logits are all downsampled (due to max-pooling layers)
for both training and evaluation.
Args:
images: A tensor of size [batch, height, width, channels].
model_options: A ModelOptions instance to configure models.
image_pyramid: Input image scales for multi-scale feature extraction.
clone_batch_size: Integer, the number of videos in a batch.
reference_labels: The segmentation labels of the reference frame on which
attention is applied.
num_frames_per_video: Integer, the number of frames per video.
embedding_dimension: Integer, the dimension of the embedding.
max_neighbors_per_object: Integer, the maximum number of candidates
for the nearest neighbor query per object after subsampling.
Can be 0 for no subsampling.
weight_decay: The weight decay for model variables.
is_training: Is training or not.
fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
k_nearest_neighbors: Integer, the number of nearest neighbors to use.
use_softmax_feedback: Boolean, whether to give the softmax predictions of
the last frame as additional input to the segmentation head.
initial_softmax_feedback: List of Float32 tensors, or None.
Can be used to initialize the softmax predictions used for the feedback
loop. Only has an effect if use_softmax_feedback is True.
embedding_seg_feature_dimension: Integer, the dimensionality used in the
segmentation head layers.
embedding_seg_n_layers: Integer, the number of layers in the segmentation
head.
embedding_seg_kernel_size: Integer, the kernel size used in the
segmentation head.
embedding_seg_atrous_rates: List of integers of length
embedding_seg_n_layers, the atrous rates to use for the segmentation head.
normalize_nearest_neighbor_distances: Boolean, whether to normalize the
nearest neighbor distances to [0,1] using sigmoid, scale and shift.
also_attend_to_previous_frame: Boolean, whether to also use nearest
neighbor attention with respect to the previous frame.
damage_initial_previous_frame_mask: Boolean, whether to artificially damage
the initial previous frame mask. Only has an effect if
also_attend_to_previous_frame is True.
use_local_previous_frame_attention: Boolean, whether to restrict the
previous frame attention to a local search window.
Only has an effect if also_attend_to_previous_frame is True.
previous_frame_attention_window_size: Integer, the window size used for
local previous frame attention, if use_local_previous_frame_attention
is True.
use_first_frame_matching: Boolean, whether to extract features by matching
to the reference frame. This should always be true except for ablation
experiments.
also_return_embeddings: Boolean, whether to return the embeddings as well.
ref_embeddings: Tuple of
(first_frame_embeddings, previous_frame_embeddings),
each of shape [batch, height, width, embedding_dimension], or None.
Returns:
outputs_to_scales_to_logits: A map of maps from output_type (e.g.,
semantic prediction) to a dictionary of multi-scale logits names to
logits. For each output_type, the dictionary has keys which
correspond to the scales and values which correspond to the logits.
For example, if `scales` equals [1.0, 1.5], then the keys would
include 'merged_logits', 'logits_1.00' and 'logits_1.50'.
If also_return_embeddings is True, it will also return an embeddings
tensor of shape [batch, height, width, embedding_dimension].
Raises:
ValueError: If model_options doesn't specify crop_size and its
add_image_level_feature = True, since add_image_level_feature requires
crop_size information.
"""
# Setup default values.
if not image_pyramid:
image_pyramid = [1.0]
crop_height = (
model_options.crop_size[0]
if model_options.crop_size else tf.shape(images)[1])
crop_width = (
model_options.crop_size[1]
if model_options.crop_size else tf.shape(images)[2])
# Compute the height, width for the output logits.
if model_options.decoder_output_stride:
logits_output_stride = min(model_options.decoder_output_stride)
else:
logits_output_stride = model_options.output_stride
logits_height = scale_dimension(
crop_height,
max(1.0, max(image_pyramid)) / logits_output_stride)
logits_width = scale_dimension(
crop_width,
max(1.0, max(image_pyramid)) / logits_output_stride)
# Compute the logits for each scale in the image pyramid.
outputs_to_scales_to_logits = {
k: {}
for k in model_options.outputs_to_num_classes
}
for image_scale in image_pyramid:
if image_scale != 1.0:
scaled_height = scale_dimension(crop_height, image_scale)
scaled_width = scale_dimension(crop_width, image_scale)
scaled_crop_size = [scaled_height, scaled_width]
scaled_images = tf.image.resize_bilinear(
images, scaled_crop_size, align_corners=True)
scaled_reference_labels = tf.image.resize_nearest_neighbor(
reference_labels, scaled_crop_size, align_corners=True
)
if model_options.crop_size is None:
scaled_crop_size = None
if model_options.crop_size:
scaled_images.set_shape([None, scaled_height, scaled_width, 3])
else:
scaled_crop_size = model_options.crop_size
scaled_images = images
scaled_reference_labels = reference_labels
updated_options = model_options._replace(crop_size=scaled_crop_size)
res = embedding_utils.get_logits_with_matching(
scaled_images,
updated_options,
weight_decay=weight_decay,
reuse=tf.AUTO_REUSE,
is_training=is_training,
fine_tune_batch_norm=fine_tune_batch_norm,
reference_labels=scaled_reference_labels,
batch_size=clone_batch_size,
num_frames_per_video=num_frames_per_video,
embedding_dimension=embedding_dimension,
max_neighbors_per_object=max_neighbors_per_object,
k_nearest_neighbors=k_nearest_neighbors,
use_softmax_feedback=use_softmax_feedback,
initial_softmax_feedback=initial_softmax_feedback,
embedding_seg_feature_dimension=embedding_seg_feature_dimension,
embedding_seg_n_layers=embedding_seg_n_layers,
embedding_seg_kernel_size=embedding_seg_kernel_size,
embedding_seg_atrous_rates=embedding_seg_atrous_rates,
normalize_nearest_neighbor_distances=
normalize_nearest_neighbor_distances,
also_attend_to_previous_frame=also_attend_to_previous_frame,
damage_initial_previous_frame_mask=damage_initial_previous_frame_mask,
use_local_previous_frame_attention=use_local_previous_frame_attention,
previous_frame_attention_window_size=
previous_frame_attention_window_size,
use_first_frame_matching=use_first_frame_matching,
also_return_embeddings=also_return_embeddings,
ref_embeddings=ref_embeddings
)
if also_return_embeddings:
outputs_to_logits, embeddings = res
else:
outputs_to_logits = res
embeddings = None
# Resize the logits to have the same dimension before merging.
for output in sorted(outputs_to_logits):
if isinstance(outputs_to_logits[output], collections.Sequence):
outputs_to_logits[output] = [tf.image.resize_bilinear(
x, [logits_height, logits_width], align_corners=True)
for x in outputs_to_logits[output]]
else:
outputs_to_logits[output] = tf.image.resize_bilinear(
outputs_to_logits[output], [logits_height, logits_width],
align_corners=True)
# Return when only one input scale.
if len(image_pyramid) == 1:
for output in sorted(model_options.outputs_to_num_classes):
outputs_to_scales_to_logits[output][
MERGED_LOGITS_SCOPE] = outputs_to_logits[output]
if also_return_embeddings:
return outputs_to_scales_to_logits, embeddings
else:
return outputs_to_scales_to_logits
# Save logits to the output map.
for output in sorted(model_options.outputs_to_num_classes):
outputs_to_scales_to_logits[output][
'logits_%.2f' % image_scale] = outputs_to_logits[output]
# Merge the logits from all the multi-scale inputs.
for output in sorted(model_options.outputs_to_num_classes):
# Concatenate the multi-scale logits for each output type.
all_logits = [
[tf.expand_dims(l, axis=4)]
for logits in outputs_to_scales_to_logits[output].values()
for l in logits
]
transposed = map(list, zip(*all_logits))
all_logits = [tf.concat(t, 4) for t in transposed]
merge_fn = (
tf.reduce_max
if model_options.merge_method == 'max' else tf.reduce_mean)
outputs_to_scales_to_logits[output][MERGED_LOGITS_SCOPE] = [merge_fn(
l, axis=4) for l in all_logits]
if also_return_embeddings:
return outputs_to_scales_to_logits, embeddings
else:
return outputs_to_scales_to_logits
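# Minimal sketch of the multi-scale merge performed above (an illustration, not
# part of the original code): logits from each scale are stacked along a new
# fifth axis and reduced with max or mean, depending on merge_method.
#
#   logits_scale_1 = tf.zeros([2, 60, 60, 21])
#   logits_scale_2 = tf.ones([2, 60, 60, 21])
#   stacked = tf.concat([tf.expand_dims(l, axis=4)
#                        for l in (logits_scale_1, logits_scale_2)], axis=4)
#   merged = tf.reduce_max(stacked, axis=4)  # tf.reduce_mean when merge_method != 'max'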
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training script for the FEELVOS model.
See model.py for more details and usage.
"""
import six
import tensorflow as tf
from feelvos import common
from feelvos import model
from feelvos.datasets import video_dataset
from feelvos.utils import embedding_utils
from feelvos.utils import train_utils
from feelvos.utils import video_input_generator
from deployment import model_deploy
slim = tf.contrib.slim
prefetch_queue = slim.prefetch_queue
flags = tf.app.flags
FLAGS = flags.FLAGS
# Settings for multi-GPUs/multi-replicas training.
flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy.')
flags.DEFINE_boolean('clone_on_cpu', False, 'Use CPUs to deploy clones.')
flags.DEFINE_integer('num_replicas', 1, 'Number of worker replicas.')
flags.DEFINE_integer('startup_delay_steps', 15,
'Number of training steps between replicas startup.')
flags.DEFINE_integer('num_ps_tasks', 0,
'The number of parameter servers. If the value is 0, then '
'the parameters are handled locally by the worker.')
flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')
flags.DEFINE_integer('task', 0, 'The task ID.')
# Settings for logging.
flags.DEFINE_string('train_logdir', None,
'Where the checkpoint and logs are stored.')
flags.DEFINE_integer('log_steps', 10,
'Display logging information at every log_steps.')
flags.DEFINE_integer('save_interval_secs', 1200,
'How often, in seconds, we save the model to disk.')
flags.DEFINE_integer('save_summaries_secs', 600,
'How often, in seconds, we compute the summaries.')
# Settings for training strategy.
flags.DEFINE_enum('learning_policy', 'poly', ['poly', 'step'],
'Learning rate policy for training.')
flags.DEFINE_float('base_learning_rate', 0.0007,
'The base learning rate for model training.')
flags.DEFINE_float('learning_rate_decay_factor', 0.1,
'The rate to decay the base learning rate.')
flags.DEFINE_integer('learning_rate_decay_step', 2000,
'Decay the base learning rate at a fixed step.')
flags.DEFINE_float('learning_power', 0.9,
'The power value used in the poly learning policy.')
flags.DEFINE_integer('training_number_of_steps', 200000,
'The number of steps used for training')
flags.DEFINE_float('momentum', 0.9, 'The momentum value to use')
flags.DEFINE_integer('train_batch_size', 6,
'The number of images in each batch during training.')
flags.DEFINE_integer('train_num_frames_per_video', 3,
'The number of frames used per video during training')
flags.DEFINE_float('weight_decay', 0.00004,
'The value of the weight decay for training.')
flags.DEFINE_multi_integer('train_crop_size', [465, 465],
'Image crop size [height, width] during training.')
flags.DEFINE_float('last_layer_gradient_multiplier', 1.0,
'The gradient multiplier for last layers, which is used to '
'boost the gradient of last layers if the value > 1.')
flags.DEFINE_boolean('upsample_logits', True,
'Upsample logits during training.')
flags.DEFINE_integer('batch_capacity_factor', 16, 'Batch capacity factor.')
flags.DEFINE_integer('num_readers', 1, 'Number of readers for data provider.')
flags.DEFINE_integer('batch_num_threads', 1, 'Batch number of threads.')
flags.DEFINE_integer('prefetch_queue_capacity_factor', 32,
'Prefetch queue capacity factor.')
flags.DEFINE_integer('prefetch_queue_num_threads', 1,
'Prefetch queue number of threads.')
flags.DEFINE_integer('train_max_neighbors_per_object', 1024,
'The maximum number of candidates for the nearest '
'neighbor query per object after subsampling')
# Settings for fine-tuning the network.
flags.DEFINE_string('tf_initial_checkpoint', None,
'The initial checkpoint in tensorflow format.')
flags.DEFINE_boolean('initialize_last_layer', False,
'Initialize the last layer.')
flags.DEFINE_boolean('last_layers_contain_logits_only', False,
'Only consider logits as last layers or not.')
flags.DEFINE_integer('slow_start_step', 0,
'Training model with small learning rate for few steps.')
flags.DEFINE_float('slow_start_learning_rate', 1e-4,
'Learning rate employed during slow start.')
flags.DEFINE_boolean('fine_tune_batch_norm', False,
'Fine tune the batch norm parameters or not.')
flags.DEFINE_float('min_scale_factor', 1.,
'Minimum scale factor for data augmentation.')
flags.DEFINE_float('max_scale_factor', 1.3,
'Maximum scale factor for data augmentation.')
flags.DEFINE_float('scale_factor_step_size', 0,
'Scale factor step size for data augmentation.')
flags.DEFINE_multi_integer('atrous_rates', None,
'Atrous rates for atrous spatial pyramid pooling.')
flags.DEFINE_integer('output_stride', 8,
'The ratio of input to output spatial resolution.')
flags.DEFINE_boolean('sample_only_first_frame_for_finetuning', False,
'Whether to only sample the first frame during '
'fine-tuning. This should be False when using lucid data, '
'but True when fine-tuning on the first frame only. Only '
'has an effect if first_frame_finetuning is True.')
flags.DEFINE_multi_integer('first_frame_finetuning', [0],
'Whether to only sample the first frame for '
'fine-tuning.')
# Dataset settings.
flags.DEFINE_multi_string('dataset', [], 'Name of the segmentation datasets.')
flags.DEFINE_multi_float('dataset_sampling_probabilities', [],
'A list of probabilities to sample each of the '
'datasets.')
flags.DEFINE_string('train_split', 'train',
'Which split of the dataset to be used for training')
flags.DEFINE_multi_string('dataset_dir', [], 'Where the datasets reside.')
flags.DEFINE_multi_integer('three_frame_dataset', [0],
'Whether the dataset has exactly three frames per '
'video of which the first is to be used as reference'
' and the two others are consecutive frames to be '
'used as query frames. '
'Set true for pascal lucid data.')
flags.DEFINE_boolean('damage_initial_previous_frame_mask', False,
'Whether to artificially damage the initial previous '
'frame mask. Only has an effect if '
'also_attend_to_previous_frame is True.')
flags.DEFINE_float('top_k_percent_pixels', 0.15, 'Float in [0.0, 1.0]. '
'When its value < 1.0, only compute the loss for the top k '
'percent pixels (e.g., the top 20% pixels). This is useful '
'for hard pixel mining.')
flags.DEFINE_integer('hard_example_mining_step', 100000,
'The training step in which the hard example mining '
'kicks off. Note that we gradually reduce the mining '
'percent to the top_k_percent_pixels. For example, if '
'hard_example_mining_step=100K and '
'top_k_percent_pixels=0.25, then mining percent will '
'gradually reduce from 100% to 25% until 100K steps '
'after which we only mine top 25% pixels. Only has an '
'effect if top_k_percent_pixels < 1.0')
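# Hedged sketch of the mining schedule described above (an assumption; see
# train_utils.add_dynamic_softmax_cross_entropy_loss_for_each_scale for the
# actual implementation): the fraction of pixels used for the loss ramps
# linearly from 100% down to top_k_percent_pixels over hard_example_mining_step
# steps.
#
#   def _mining_percent(global_step, hard_example_mining_step,
#                       top_k_percent_pixels):
#     ratio = tf.minimum(1.0, tf.to_float(global_step) /
#                        float(hard_example_mining_step))
#     return (1.0 - ratio) + ratio * top_k_percent_pixels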
def _build_deeplab(inputs_queue_or_samples, outputs_to_num_classes,
ignore_label):
"""Builds a clone of DeepLab.
Args:
inputs_queue_or_samples: A prefetch queue for images and labels, or
directly a dict of the samples.
outputs_to_num_classes: A map from output type to the number of classes.
For example, for the task of semantic segmentation with 21 semantic
classes, we would have outputs_to_num_classes['semantic'] = 21.
ignore_label: Ignore label.
Returns:
A map of maps from output_type (e.g., semantic prediction) to a
dictionary of multi-scale logits names to logits. For each output_type,
the dictionary has keys which correspond to the scales and values which
correspond to the logits. For example, if `scales` equals [1.0, 1.5],
then the keys would include 'merged_logits', 'logits_1.00' and
'logits_1.50'.
Raises:
ValueError: If classification_loss is not softmax, softmax_with_attention,
or triplet.
"""
if hasattr(inputs_queue_or_samples, 'dequeue'):
samples = inputs_queue_or_samples.dequeue()
else:
samples = inputs_queue_or_samples
train_crop_size = (None if 0 in FLAGS.train_crop_size else
FLAGS.train_crop_size)
model_options = common.VideoModelOptions(
outputs_to_num_classes=outputs_to_num_classes,
crop_size=train_crop_size,
atrous_rates=FLAGS.atrous_rates,
output_stride=FLAGS.output_stride)
if model_options.classification_loss == 'softmax_with_attention':
clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones
# Create summaries of ground truth labels.
for n in range(clone_batch_size):
tf.summary.image(
'gt_label_%d' % n,
tf.cast(samples[common.LABEL][
n * FLAGS.train_num_frames_per_video:
(n + 1) * FLAGS.train_num_frames_per_video],
tf.uint8) * 32, max_outputs=FLAGS.train_num_frames_per_video)
if common.PRECEDING_FRAME_LABEL in samples:
preceding_frame_label = samples[common.PRECEDING_FRAME_LABEL]
init_softmax = []
for n in range(clone_batch_size):
init_softmax_n = embedding_utils.create_initial_softmax_from_labels(
preceding_frame_label[n, tf.newaxis],
samples[common.LABEL][n * FLAGS.train_num_frames_per_video,
tf.newaxis],
common.parse_decoder_output_stride(),
reduce_labels=True)
init_softmax_n = tf.squeeze(init_softmax_n, axis=0)
init_softmax.append(init_softmax_n)
tf.summary.image('preceding_frame_label',
tf.cast(preceding_frame_label[n, tf.newaxis],
tf.uint8) * 32)
else:
init_softmax = None
outputs_to_scales_to_logits = (
model.multi_scale_logits_with_nearest_neighbor_matching(
samples[common.IMAGE],
model_options=model_options,
image_pyramid=FLAGS.image_pyramid,
weight_decay=FLAGS.weight_decay,
is_training=True,
fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
reference_labels=samples[common.LABEL],
clone_batch_size=FLAGS.train_batch_size // FLAGS.num_clones,
num_frames_per_video=FLAGS.train_num_frames_per_video,
embedding_dimension=FLAGS.embedding_dimension,
max_neighbors_per_object=FLAGS.train_max_neighbors_per_object,
k_nearest_neighbors=FLAGS.k_nearest_neighbors,
use_softmax_feedback=FLAGS.use_softmax_feedback,
initial_softmax_feedback=init_softmax,
embedding_seg_feature_dimension=
FLAGS.embedding_seg_feature_dimension,
embedding_seg_n_layers=FLAGS.embedding_seg_n_layers,
embedding_seg_kernel_size=FLAGS.embedding_seg_kernel_size,
embedding_seg_atrous_rates=FLAGS.embedding_seg_atrous_rates,
normalize_nearest_neighbor_distances=
FLAGS.normalize_nearest_neighbor_distances,
also_attend_to_previous_frame=FLAGS.also_attend_to_previous_frame,
damage_initial_previous_frame_mask=
FLAGS.damage_initial_previous_frame_mask,
use_local_previous_frame_attention=
FLAGS.use_local_previous_frame_attention,
previous_frame_attention_window_size=
FLAGS.previous_frame_attention_window_size,
use_first_frame_matching=FLAGS.use_first_frame_matching
))
else:
outputs_to_scales_to_logits = model.multi_scale_logits_v2(
samples[common.IMAGE],
model_options=model_options,
image_pyramid=FLAGS.image_pyramid,
weight_decay=FLAGS.weight_decay,
is_training=True,
fine_tune_batch_norm=FLAGS.fine_tune_batch_norm)
if model_options.classification_loss == 'softmax':
for output, num_classes in six.iteritems(outputs_to_num_classes):
train_utils.add_softmax_cross_entropy_loss_for_each_scale(
outputs_to_scales_to_logits[output],
samples[common.LABEL],
num_classes,
ignore_label,
loss_weight=1.0,
upsample_logits=FLAGS.upsample_logits,
scope=output)
elif model_options.classification_loss == 'triplet':
for output, _ in six.iteritems(outputs_to_num_classes):
train_utils.add_triplet_loss_for_each_scale(
FLAGS.train_batch_size // FLAGS.num_clones,
FLAGS.train_num_frames_per_video,
FLAGS.embedding_dimension, outputs_to_scales_to_logits[output],
samples[common.LABEL], scope=output)
elif model_options.classification_loss == 'softmax_with_attention':
labels = samples[common.LABEL]
batch_size = FLAGS.train_batch_size // FLAGS.num_clones
num_frames_per_video = FLAGS.train_num_frames_per_video
h, w = train_utils.resolve_shape(labels)[1:3]
labels = tf.reshape(labels, tf.stack(
[batch_size, num_frames_per_video, h, w, 1]))
# Strip the reference labels off.
if FLAGS.also_attend_to_previous_frame or FLAGS.use_softmax_feedback:
n_ref_frames = 2
else:
n_ref_frames = 1
labels = labels[:, n_ref_frames:]
# Merge batch and time dimensions.
labels = tf.reshape(labels, tf.stack(
[batch_size * (num_frames_per_video - n_ref_frames), h, w, 1]))
for output, num_classes in six.iteritems(outputs_to_num_classes):
train_utils.add_dynamic_softmax_cross_entropy_loss_for_each_scale(
outputs_to_scales_to_logits[output],
labels,
ignore_label,
loss_weight=1.0,
upsample_logits=FLAGS.upsample_logits,
scope=output,
top_k_percent_pixels=FLAGS.top_k_percent_pixels,
hard_example_mining_step=FLAGS.hard_example_mining_step)
else:
raise ValueError('Only support softmax, softmax_with_attention'
' or triplet for classification_loss.')
return outputs_to_scales_to_logits
def main(unused_argv):
# Set up deployment (i.e., multi-GPUs and/or multi-replicas).
config = model_deploy.DeploymentConfig(
num_clones=FLAGS.num_clones,
clone_on_cpu=FLAGS.clone_on_cpu,
replica_id=FLAGS.task,
num_replicas=FLAGS.num_replicas,
num_ps_tasks=FLAGS.num_ps_tasks)
with tf.Graph().as_default():
with tf.device(config.inputs_device()):
train_crop_size = (None if 0 in FLAGS.train_crop_size else
FLAGS.train_crop_size)
assert FLAGS.dataset
assert len(FLAGS.dataset) == len(FLAGS.dataset_dir)
if len(FLAGS.first_frame_finetuning) == 1:
first_frame_finetuning = (list(FLAGS.first_frame_finetuning)
* len(FLAGS.dataset))
else:
first_frame_finetuning = FLAGS.first_frame_finetuning
if len(FLAGS.three_frame_dataset) == 1:
three_frame_dataset = (list(FLAGS.three_frame_dataset)
* len(FLAGS.dataset))
else:
three_frame_dataset = FLAGS.three_frame_dataset
assert len(FLAGS.dataset) == len(first_frame_finetuning)
assert len(FLAGS.dataset) == len(three_frame_dataset)
datasets, samples_list = zip(
*[_get_dataset_and_samples(config, train_crop_size, dataset,
dataset_dir, bool(first_frame_finetuning_),
bool(three_frame_dataset_))
for dataset, dataset_dir, first_frame_finetuning_,
three_frame_dataset_ in zip(FLAGS.dataset, FLAGS.dataset_dir,
first_frame_finetuning,
three_frame_dataset)])
# Note that this way of doing things is wasteful since it will evaluate
# all branches but just use one of them. But let's do it anyway for now,
# since it's easy and will probably be fast enough.
dataset = datasets[0]
if len(samples_list) == 1:
samples = samples_list[0]
else:
probabilities = FLAGS.dataset_sampling_probabilities
if probabilities:
assert len(probabilities) == len(samples_list)
else:
# Default to uniform probabilities.
probabilities = [1.0 / len(samples_list) for _ in samples_list]
probabilities = tf.constant(probabilities)
logits = tf.log(probabilities[tf.newaxis])
rand_idx = tf.squeeze(tf.multinomial(logits, 1, output_dtype=tf.int32),
axis=[0, 1])
def wrap(x):
def f():
return x
return f
samples = tf.case({tf.equal(rand_idx, idx): wrap(s)
for idx, s in enumerate(samples_list)},
exclusive=True)
# Prefetch_queue requires the shape to be known at graph creation time.
# So we only use it if we crop to a fixed size.
if train_crop_size is None:
inputs_queue = samples
else:
inputs_queue = prefetch_queue.prefetch_queue(
samples,
capacity=FLAGS.prefetch_queue_capacity_factor*config.num_clones,
num_threads=FLAGS.prefetch_queue_num_threads)
# Create the global step on the device storing the variables.
with tf.device(config.variables_device()):
global_step = tf.train.get_or_create_global_step()
# Define the model and create clones.
model_fn = _build_deeplab
if FLAGS.classification_loss == 'triplet':
embedding_dim = FLAGS.embedding_dimension
output_type_to_dim = {'embedding': embedding_dim}
else:
output_type_to_dim = {common.OUTPUT_TYPE: dataset.num_classes}
model_args = (inputs_queue, output_type_to_dim, dataset.ignore_label)
clones = model_deploy.create_clones(config, model_fn, args=model_args)
# Gather update_ops from the first clone. These contain, for example,
# the updates for the batch_norm variables created by model_fn.
first_clone_scope = config.clone_scope(0)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
# Gather initial summaries.
summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
# Add summaries for model variables.
for model_var in tf.contrib.framework.get_model_variables():
summaries.add(tf.summary.histogram(model_var.op.name, model_var))
# Add summaries for losses.
for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))
# Build the optimizer based on the device specification.
with tf.device(config.optimizer_device()):
learning_rate = train_utils.get_model_learning_rate(
FLAGS.learning_policy,
FLAGS.base_learning_rate,
FLAGS.learning_rate_decay_step,
FLAGS.learning_rate_decay_factor,
FLAGS.training_number_of_steps,
FLAGS.learning_power,
FLAGS.slow_start_step,
FLAGS.slow_start_learning_rate)
optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
summaries.add(tf.summary.scalar('learning_rate', learning_rate))
startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
with tf.device(config.variables_device()):
total_loss, grads_and_vars = model_deploy.optimize_clones(
clones, optimizer)
total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
summaries.add(tf.summary.scalar('total_loss', total_loss))
# Modify the gradients for biases and last layer variables.
last_layers = model.get_extra_layer_scopes(
FLAGS.last_layers_contain_logits_only)
grad_mult = train_utils.get_model_gradient_multipliers(
last_layers, FLAGS.last_layer_gradient_multiplier)
if grad_mult:
grads_and_vars = slim.learning.multiply_gradients(grads_and_vars,
grad_mult)
with tf.name_scope('grad_clipping'):
grads_and_vars = slim.learning.clip_gradient_norms(grads_and_vars, 5.0)
# Create histogram summaries for the gradients.
# We have too many summaries for mldash, so disable this one for now.
# for grad, var in grads_and_vars:
# summaries.add(tf.summary.histogram(
# var.name.replace(':0', '_0') + '/gradient', grad))
# Create gradient update op.
grad_updates = optimizer.apply_gradients(grads_and_vars,
global_step=global_step)
update_ops.append(grad_updates)
update_op = tf.group(*update_ops)
with tf.control_dependencies([update_op]):
train_tensor = tf.identity(total_loss, name='train_op')
# Add the summaries from the first clone. These contain the summaries
# created by model_fn and either optimize_clones() or _gather_clone_loss().
summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
first_clone_scope))
# Merge all summaries together.
summary_op = tf.summary.merge(list(summaries))
# Soft placement allows placing ops without a GPU implementation on the CPU.
session_config = tf.ConfigProto(allow_soft_placement=True,
log_device_placement=False)
# Start the training.
slim.learning.train(
train_tensor,
logdir=FLAGS.train_logdir,
log_every_n_steps=FLAGS.log_steps,
master=FLAGS.master,
number_of_steps=FLAGS.training_number_of_steps,
is_chief=(FLAGS.task == 0),
session_config=session_config,
startup_delay_steps=startup_delay_steps,
init_fn=train_utils.get_model_init_fn(FLAGS.train_logdir,
FLAGS.tf_initial_checkpoint,
FLAGS.initialize_last_layer,
last_layers,
ignore_missing_vars=True),
summary_op=summary_op,
save_summaries_secs=FLAGS.save_summaries_secs,
save_interval_secs=FLAGS.save_interval_secs)
def _get_dataset_and_samples(config, train_crop_size, dataset_name,
dataset_dir, first_frame_finetuning,
three_frame_dataset):
"""Creates dataset object and samples dict of tensor.
Args:
config: A DeploymentConfig.
train_crop_size: Integer, the crop size used for training.
dataset_name: String, the name of the dataset.
dataset_dir: String, the directory of the dataset.
first_frame_finetuning: Boolean, whether the used dataset is a dataset
for first frame fine-tuning.
three_frame_dataset: Boolean, whether the dataset has exactly three frames
per video of which the first is to be used as reference and the two
others are consecutive frames to be used as query frames.
Returns:
dataset: An instance of slim Dataset.
samples: A dictionary of tensors for semantic segmentation.
"""
# Split the batch across GPUs.
assert FLAGS.train_batch_size % config.num_clones == 0, (
'Training batch size not divisible by number of clones (GPUs).')
clone_batch_size = FLAGS.train_batch_size // config.num_clones
if first_frame_finetuning:
train_split = 'val'
else:
train_split = FLAGS.train_split
data_type = 'tf_sequence_example'
# Get dataset-dependent information.
dataset = video_dataset.get_dataset(
dataset_name,
train_split,
dataset_dir=dataset_dir,
data_type=data_type)
tf.gfile.MakeDirs(FLAGS.train_logdir)
tf.logging.info('Training on %s set', train_split)
samples = video_input_generator.get(
dataset,
FLAGS.train_num_frames_per_video,
train_crop_size,
clone_batch_size,
num_readers=FLAGS.num_readers,
num_threads=FLAGS.batch_num_threads,
min_resize_value=FLAGS.min_resize_value,
max_resize_value=FLAGS.max_resize_value,
resize_factor=FLAGS.resize_factor,
min_scale_factor=FLAGS.min_scale_factor,
max_scale_factor=FLAGS.max_scale_factor,
scale_factor_step_size=FLAGS.scale_factor_step_size,
dataset_split=FLAGS.train_split,
is_training=True,
model_variant=FLAGS.model_variant,
batch_capacity_factor=FLAGS.batch_capacity_factor,
decoder_output_stride=common.parse_decoder_output_stride(),
first_frame_finetuning=first_frame_finetuning,
sample_only_first_frame_for_finetuning=
FLAGS.sample_only_first_frame_for_finetuning,
sample_adjacent_and_consistent_query_frames=
FLAGS.sample_adjacent_and_consistent_query_frames or
FLAGS.use_softmax_feedback,
remap_labels_to_reference_frame=True,
three_frame_dataset=three_frame_dataset,
add_prev_frame_label=not FLAGS.also_attend_to_previous_frame
)
return dataset, samples
if __name__ == '__main__':
flags.mark_flag_as_required('train_logdir')
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run()
#!/bin/bash
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# This script is used to run local training on DAVIS 2017. Users could also
# modify from this script for their use case. See eval.sh for an example of
# local inference with a pre-trained model.
#
# Note that this script runs local training with a single GPU and a smaller crop
# and batch size, while in the paper, we trained our models with 16 GPUs with
# --num_clones=2, --train_batch_size=6, --num_replicas=8,
# --training_number_of_steps=200000, --train_crop_size=465,
# --train_crop_size=465.
#
# Usage:
# # From the tensorflow/models/research/feelvos directory.
# sh ./train.sh
#
#
# Exit immediately if a command exits with a non-zero status.
set -e
# Move one-level up to tensorflow/models/research directory.
cd ..
# Update PYTHONPATH.
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim:`pwd`/feelvos
# Set up the working environment.
CURRENT_DIR=$(pwd)
WORK_DIR="${CURRENT_DIR}/feelvos"
# Set up the working directories.
DATASET_DIR="datasets"
DAVIS_FOLDER="davis17"
DAVIS_DATASET="${WORK_DIR}/${DATASET_DIR}/${DAVIS_FOLDER}/tfrecord"
EXP_FOLDER="exp/train"
TRAIN_LOGDIR="${WORK_DIR}/${DATASET_DIR}/${DAVIS_FOLDER}/${EXP_FOLDER}/train"
mkdir -p ${TRAIN_LOGDIR}
# Go to datasets folder and download and convert the DAVIS 2017 dataset.
DATASET_DIR="datasets"
cd "${WORK_DIR}/${DATASET_DIR}"
sh download_and_convert_davis17.sh
# Go to models folder and download and unpack the COCO pre-trained model.
MODELS_DIR="models"
mkdir -p "${WORK_DIR}/${MODELS_DIR}"
cd "${WORK_DIR}/${MODELS_DIR}"
if [ ! -d "xception_65_coco_pretrained" ]; then
wget http://download.tensorflow.org/models/xception_65_coco_pretrained_2018_10_02.tar.gz
tar -xvf xception_65_coco_pretrained_2018_10_02.tar.gz
rm xception_65_coco_pretrained_2018_10_02.tar.gz
fi
INIT_CKPT="${WORK_DIR}/${MODELS_DIR}/xception_65_coco_pretrained/x65-b2u1s2p-d48-2-3x256-sc-cr300k_init.ckpt"
# Go back to original directory.
cd "${CURRENT_DIR}"
python "${WORK_DIR}"/train.py \
--dataset=davis_2017 \
--dataset_dir="${DAVIS_DATASET}" \
--train_logdir="${TRAIN_LOGDIR}" \
--tf_initial_checkpoint="${INIT_CKPT}" \
--logtostderr \
--atrous_rates=6 \
--atrous_rates=12 \
--atrous_rates=18 \
--decoder_output_stride=4 \
--model_variant=xception_65 \
--multi_grid=1 \
--multi_grid=1 \
--multi_grid=1 \
--output_stride=16 \
--weight_decay=0.00004 \
--num_clones=1 \
--train_batch_size=1 \
--train_crop_size=300 \
--train_crop_size=300
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for the instance embedding for segmentation."""
import numpy as np
import tensorflow as tf
from deeplab import model
from deeplab.core import preprocess_utils
from feelvos.utils import mask_damaging
slim = tf.contrib.slim
resolve_shape = preprocess_utils.resolve_shape
WRONG_LABEL_PADDING_DISTANCE = 1e20
# With correlation_cost local matching will be much faster. But we provide a
# slow fallback for convenience.
USE_CORRELATION_COST = False
if USE_CORRELATION_COST:
# pylint: disable=g-import-not-at-top
from correlation_cost.python.ops import correlation_cost_op
def pairwise_distances(x, y):
"""Computes pairwise squared l2 distances between tensors x and y.
Args:
x: Tensor of shape [n, feature_dim].
y: Tensor of shape [m, feature_dim].
Returns:
Float32 distances tensor of shape [n, m].
"""
# d[i,j] = (x[i] - y[j]) * (x[i] - y[j])'
# = sum(x[i]^2, 1) + sum(y[j]^2, 1) - 2 * x[i] * y[j]'
xs = tf.reduce_sum(x * x, axis=1)[:, tf.newaxis]
ys = tf.reduce_sum(y * y, axis=1)[tf.newaxis, :]
d = xs + ys - 2 * tf.matmul(x, y, transpose_b=True)
return d
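# Illustrative check, not part of the original file: the expansion above is
# equivalent (up to floating point error) to the naive broadcasted form used in
# pairwise_distances2 below.
#
#   x = tf.random_normal([4, 8])
#   y = tf.random_normal([5, 8])
#   fast = pairwise_distances(x, y)   # [4, 5]
#   naive = tf.reduce_sum(tf.squared_difference(
#       x[:, tf.newaxis], y[tf.newaxis, :]), axis=-1)
#   # tf.reduce_max(tf.abs(fast - naive)) evaluates to approximately 0.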
def pairwise_distances2(x, y):
"""Computes pairwise squared l2 distances between tensors x and y.
Naive implementation, high memory use. Could be useful to test the more
efficient implementation.
Args:
x: Tensor of shape [n, feature_dim].
y: Tensor of shape [m, feature_dim].
Returns:
distances of shape [n, m].
"""
return tf.reduce_sum(tf.squared_difference(
x[:, tf.newaxis], y[tf.newaxis, :]), axis=-1)
def cross_correlate(x, y, max_distance=9):
"""Efficiently computes the cross correlation of x and y.
Optimized implementation using correlation_cost.
Note that we do not normalize by the feature dimension.
Args:
x: Float32 tensor of shape [height, width, feature_dim].
y: Float32 tensor of shape [height, width, feature_dim].
max_distance: Integer, the maximum distance in pixel coordinates
per dimension which is considered to be in the search window.
Returns:
Float32 tensor of shape [height, width, (2 * max_distance + 1) ** 2].
"""
with tf.name_scope('cross_correlation'):
corr = correlation_cost_op.correlation_cost(
x[tf.newaxis], y[tf.newaxis], kernel_size=1,
max_displacement=max_distance, stride_1=1, stride_2=1,
pad=max_distance)
corr = tf.squeeze(corr, axis=0)
# This correlation implementation takes the mean over the feature_dim,
# but we want sum here, so multiply by feature_dim.
feature_dim = resolve_shape(x)[-1]
corr *= feature_dim
return corr
def local_pairwise_distances(x, y, max_distance=9):
"""Computes pairwise squared l2 distances using a local search window.
Optimized implementation using correlation_cost.
Args:
x: Float32 tensor of shape [height, width, feature_dim].
y: Float32 tensor of shape [height, width, feature_dim].
max_distance: Integer, the maximum distance in pixel coordinates
per dimension which is considered to be in the search window.
Returns:
Float32 distances tensor of shape
[height, width, (2 * max_distance + 1) ** 2].
"""
with tf.name_scope('local_pairwise_distances'):
# d[i,j] = (x[i] - y[j]) * (x[i] - y[j])'
# = sum(x[i]^2, -1) + sum(y[j]^2, -1) - 2 * x[i] * y[j]'
corr = cross_correlate(x, y, max_distance=max_distance)
xs = tf.reduce_sum(x * x, axis=2)[..., tf.newaxis]
ys = tf.reduce_sum(y * y, axis=2)[..., tf.newaxis]
ones_ys = tf.ones_like(ys)
ys = cross_correlate(ones_ys, ys, max_distance=max_distance)
d = xs + ys - 2 * corr
# Boundary should be set to Inf.
boundary = tf.equal(
cross_correlate(ones_ys, ones_ys, max_distance=max_distance), 0)
d = tf.where(boundary, tf.fill(tf.shape(d), tf.constant(float('inf'))),
d)
return d
def local_pairwise_distances2(x, y, max_distance=9):
"""Computes pairwise squared l2 distances using a local search window.
Naive implementation using map_fn.
Used as a slow fallback for when correlation_cost is not available.
Args:
x: Float32 tensor of shape [height, width, feature_dim].
y: Float32 tensor of shape [height, width, feature_dim].
max_distance: Integer, the maximum distance in pixel coordinates
per dimension which is considered to be in the search window.
Returns:
Float32 distances tensor of shape
[height, width, (2 * max_distance + 1) ** 2].
"""
with tf.name_scope('local_pairwise_distances2'):
padding_val = 1e20
padded_y = tf.pad(y, [[max_distance, max_distance],
[max_distance, max_distance], [0, 0]],
constant_values=padding_val)
height, width, _ = resolve_shape(x)
dists = []
for y_start in range(2 * max_distance + 1):
y_end = y_start + height
y_slice = padded_y[y_start:y_end]
for x_start in range(2 * max_distance + 1):
x_end = x_start + width
offset_y = y_slice[:, x_start:x_end]
dist = tf.reduce_sum(tf.squared_difference(x, offset_y), axis=2)
dists.append(dist)
dists = tf.stack(dists, axis=2)
return dists
def majority_vote(labels):
"""Performs a label majority vote along axis 1.
Second try, hopefully this time more efficient.
We assume that the labels are contiguous starting from 0.
It will also work for non-contiguous labels, but be inefficient.
Args:
labels: Int tensor of shape [n, k]
Returns:
The majority of labels along axis 1
"""
max_label = tf.reduce_max(labels)
one_hot = tf.one_hot(labels, depth=max_label + 1)
summed = tf.reduce_sum(one_hot, axis=1)
majority = tf.argmax(summed, axis=1)
return majority
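# Small worked example (illustration only, values are hypothetical):
#
#   labels = tf.constant([[0, 1, 1],
#                         [2, 2, 0]])   # shape [n=2, k=3]
#   majority_vote(labels)               # evaluates to [1, 2]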
def assign_labels_by_nearest_neighbors(reference_embeddings, query_embeddings,
reference_labels, k=1):
"""Segments by nearest neighbor query wrt the reference frame.
Args:
reference_embeddings: Tensor of shape [height, width, embedding_dim],
the embedding vectors for the reference frame
query_embeddings: Tensor of shape [n_query_images, height, width,
embedding_dim], the embedding vectors for the query frames
reference_labels: Tensor of shape [height, width, 1], the class labels of
the reference frame
k: Integer, the number of nearest neighbors to use
Returns:
The labels of the nearest neighbors as [n_query_frames, height, width, 1]
tensor
Raises:
ValueError: If k < 1.
"""
if k < 1:
raise ValueError('k must be at least 1')
dists = flattened_pairwise_distances(reference_embeddings, query_embeddings)
if k == 1:
nn_indices = tf.argmin(dists, axis=1)[..., tf.newaxis]
else:
_, nn_indices = tf.nn.top_k(-dists, k, sorted=False)
reference_labels = tf.reshape(reference_labels, [-1])
nn_labels = tf.gather(reference_labels, nn_indices)
if k == 1:
nn_labels = tf.squeeze(nn_labels, axis=1)
else:
nn_labels = majority_vote(nn_labels)
height = tf.shape(reference_embeddings)[0]
width = tf.shape(reference_embeddings)[1]
n_query_frames = query_embeddings.shape[0]
nn_labels = tf.reshape(nn_labels, [n_query_frames, height, width, 1])
return nn_labels
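# Hedged usage sketch with assumed shapes (not taken from the original code):
#
#   ref_emb = tf.random_normal([33, 33, 100])       # reference frame embeddings
#   query_emb = tf.random_normal([2, 33, 33, 100])  # two query frames
#   ref_labels = tf.zeros([33, 33, 1], dtype=tf.int32)
#   nn_labels = assign_labels_by_nearest_neighbors(
#       ref_emb, query_emb, ref_labels, k=1)        # [2, 33, 33, 1]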
def flattened_pairwise_distances(reference_embeddings, query_embeddings):
"""Calculates flattened tensor of pairwise distances between ref and query.
Args:
reference_embeddings: Tensor of shape [..., embedding_dim],
the embedding vectors for the reference frame
query_embeddings: Tensor of shape [n_query_images, height, width,
embedding_dim], the embedding vectors for the query frames.
Returns:
A distance tensor of shape [reference_embeddings.size / embedding_dim,
query_embeddings.size / embedding_dim]
"""
embedding_dim = resolve_shape(query_embeddings)[-1]
reference_embeddings = tf.reshape(reference_embeddings, [-1, embedding_dim])
first_dim = -1
query_embeddings = tf.reshape(query_embeddings, [first_dim, embedding_dim])
dists = pairwise_distances(query_embeddings, reference_embeddings)
return dists
def nearest_neighbor_features_per_object(
reference_embeddings, query_embeddings, reference_labels,
max_neighbors_per_object, k_nearest_neighbors, gt_ids=None, n_chunks=100):
"""Calculates the distance to the nearest neighbor per object.
For every pixel of query_embeddings calculate the distance to the
nearest neighbor in the (possibly subsampled) reference_embeddings per object.
Args:
reference_embeddings: Tensor of shape [height, width, embedding_dim],
the embedding vectors for the reference frame.
query_embeddings: Tensor of shape [n_query_images, height, width,
embedding_dim], the embedding vectors for the query frames.
reference_labels: Tensor of shape [height, width, 1], the class labels of
the reference frame.
max_neighbors_per_object: Integer, the maximum number of candidates
for the nearest neighbor query per object after subsampling,
or 0 for no subsampling.
k_nearest_neighbors: Integer, the number of nearest neighbors to use.
gt_ids: Int tensor of shape [n_objs] of the sorted unique ground truth
ids in the first frame. If None, it will be derived from
reference_labels.
n_chunks: Integer, the number of chunks to use to save memory
(set to 1 for no chunking).
Returns:
nn_features: A float32 tensor of nearest neighbor features of shape
[n_query_images, height, width, n_objects, feature_dim].
gt_ids: An int32 tensor of the unique sorted object ids present
in the reference labels.
"""
with tf.name_scope('nn_features_per_object'):
reference_labels_flat = tf.reshape(reference_labels, [-1])
if gt_ids is None:
ref_obj_ids, _ = tf.unique(reference_labels_flat)
ref_obj_ids = tf.contrib.framework.sort(ref_obj_ids)
gt_ids = ref_obj_ids
embedding_dim = resolve_shape(reference_embeddings)[-1]
reference_embeddings_flat = tf.reshape(reference_embeddings,
[-1, embedding_dim])
reference_embeddings_flat, reference_labels_flat = (
subsample_reference_embeddings_and_labels(reference_embeddings_flat,
reference_labels_flat,
gt_ids,
max_neighbors_per_object))
shape = resolve_shape(query_embeddings)
query_embeddings_flat = tf.reshape(query_embeddings, [-1, embedding_dim])
nn_features = _nearest_neighbor_features_per_object_in_chunks(
reference_embeddings_flat, query_embeddings_flat, reference_labels_flat,
gt_ids, k_nearest_neighbors, n_chunks)
nn_features_dim = resolve_shape(nn_features)[-1]
nn_features_reshaped = tf.reshape(nn_features,
tf.stack(shape[:3] + [tf.size(gt_ids),
nn_features_dim]))
return nn_features_reshaped, gt_ids
def _nearest_neighbor_features_per_object_in_chunks(
reference_embeddings_flat, query_embeddings_flat, reference_labels_flat,
ref_obj_ids, k_nearest_neighbors, n_chunks):
"""Calculates the nearest neighbor features per object in chunks to save mem.
Uses chunking to bound the memory use.
Args:
reference_embeddings_flat: Tensor of shape [n, embedding_dim],
the embedding vectors for the reference frame.
query_embeddings_flat: Tensor of shape [m, embedding_dim], the embedding
vectors for the query frames.
reference_labels_flat: Tensor of shape [n], the class labels of the
reference frame.
ref_obj_ids: int tensor of unique object ids in the reference labels.
k_nearest_neighbors: Integer, the number of nearest neighbors to use.
n_chunks: Integer, the number of chunks to use to save memory
(set to 1 for no chunking).
Returns:
nn_features: A float32 tensor of nearest neighbor features of shape
[m, n_objects, feature_dim].
"""
chunk_size = tf.cast(tf.ceil(tf.cast(tf.shape(query_embeddings_flat)[0],
tf.float32) / n_chunks), tf.int32)
wrong_label_mask = tf.not_equal(reference_labels_flat,
ref_obj_ids[:, tf.newaxis])
all_features = []
for n in range(n_chunks):
if n_chunks == 1:
query_embeddings_flat_chunk = query_embeddings_flat
else:
chunk_start = n * chunk_size
chunk_end = (n + 1) * chunk_size
query_embeddings_flat_chunk = query_embeddings_flat[chunk_start:chunk_end]
# Use control dependencies to make sure that the chunks are not processed
# in parallel which would prevent any peak memory savings.
with tf.control_dependencies(all_features):
features = _nn_features_per_object_for_chunk(
reference_embeddings_flat, query_embeddings_flat_chunk,
wrong_label_mask, k_nearest_neighbors
)
all_features.append(features)
if n_chunks == 1:
nn_features = all_features[0]
else:
nn_features = tf.concat(all_features, axis=0)
return nn_features
def _nn_features_per_object_for_chunk(
reference_embeddings, query_embeddings, wrong_label_mask,
k_nearest_neighbors):
"""Extracts features for each object using nearest neighbor attention.
Args:
reference_embeddings: Tensor of shape [n, embedding_dim],
the embedding vectors for the reference frame.
query_embeddings: Tensor of shape [m_chunk, embedding_dim], the embedding
vectors for the query frames of the current chunk.
wrong_label_mask: Boolean tensor of shape [n_objects, n], True where the
reference pixel does not belong to the corresponding object.
k_nearest_neighbors: Integer, the number of nearest neighbors to use.
Returns:
nn_features: A float32 tensor of nearest neighbor features of shape
[m_chunk, n_objects, feature_dim].
"""
reference_embeddings_key = reference_embeddings
query_embeddings_key = query_embeddings
dists = flattened_pairwise_distances(reference_embeddings_key,
query_embeddings_key)
dists = (dists[:, tf.newaxis, :] +
tf.cast(wrong_label_mask[tf.newaxis, :, :], tf.float32) *
WRONG_LABEL_PADDING_DISTANCE)
if k_nearest_neighbors == 1:
features = tf.reduce_min(dists, axis=2, keepdims=True)
else:
# Find the closest k distances; they are combined by taking their mean below.
dists, _ = tf.nn.top_k(-dists, k=k_nearest_neighbors)
dists = -dists
# If not enough real neighbors were found, pad with the farthest real
# neighbor.
valid_mask = tf.less(dists, WRONG_LABEL_PADDING_DISTANCE)
masked_dists = dists * tf.cast(valid_mask, tf.float32)
pad_dist = tf.tile(tf.reduce_max(masked_dists, axis=2)[..., tf.newaxis],
multiples=[1, 1, k_nearest_neighbors])
dists = tf.where(valid_mask, dists, pad_dist)
# take mean of distances
features = tf.reduce_mean(dists, axis=2, keepdims=True)
return features
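# Note on the masking above: distances to reference pixels that belong to a
# different object are shifted by the large constant
# WRONG_LABEL_PADDING_DISTANCE so that reduce_min / top_k never select them.
# Toy sketch (values are made up for illustration): for one query pixel with
#   dists = [0.2, 0.7, 0.1]   and   wrong_label_mask = [False, True, False]
# the masked row becomes [0.2, 0.7 + LARGE, 0.1], so the 1-nearest-neighbor
# feature for that object is 0.1.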
def create_embedding_segmentation_features(features, feature_dimension,
n_layers, kernel_size, reuse,
atrous_rates=None):
"""Extracts features which can be used to estimate the final segmentation.
Args:
features: input features of shape [batch, height, width, features]
feature_dimension: Integer, the dimensionality used in the segmentation
head layers.
n_layers: Integer, the number of layers in the segmentation head.
kernel_size: Integer, the kernel size used in the segmentation head.
reuse: reuse mode for the variable_scope.
atrous_rates: List of integers of length n_layers, the atrous rates to use.
Returns:
Features to be used to estimate the segmentation labels of shape
[batch, height, width, embedding_seg_feat_dim].
"""
if atrous_rates is None or not atrous_rates:
atrous_rates = [1 for _ in range(n_layers)]
assert len(atrous_rates) == n_layers
with tf.variable_scope('embedding_seg', reuse=reuse):
for n in range(n_layers):
features = model.split_separable_conv2d(
features, feature_dimension, kernel_size=kernel_size,
rate=atrous_rates[n], scope='split_separable_conv2d_{}'.format(n))
return features
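# Sketch of how this segmentation head is applied (shapes are assumptions,
# purely illustrative):
#   feats = tf.random_normal([3, 65, 65, 301])  # concat of decoder + matching features
#   seg_feats = create_embedding_segmentation_features(
#       feats, feature_dimension=256, n_layers=4, kernel_size=7, reuse=False)
#   # seg_feats: [3, 65, 65, 256]; with atrous_rates=None each layer uses rate 1.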
def add_image_summaries(images, nn_features, logits, batch_size,
prev_frame_nn_features=None):
"""Adds image summaries of input images, attention features and logits.
Args:
images: Image tensor of shape [batch, height, width, channels].
nn_features: Nearest neighbor attention features of shape
[batch_size, height, width, n_objects, 1].
logits: Float32 tensor of logits.
batch_size: Integer, the number of videos per clone per mini-batch.
prev_frame_nn_features: Nearest neighbor attention features wrt. the
last frame of shape [batch_size, height, width, n_objects, 1].
Can be None.
"""
# Separate reference and query images.
reshaped_images = tf.reshape(images, tf.stack(
[batch_size, -1] + resolve_shape(images)[1:]))
reference_images = reshaped_images[:, 0]
query_images = reshaped_images[:, 1:]
query_images_reshaped = tf.reshape(query_images, tf.stack(
[-1] + resolve_shape(images)[1:]))
tf.summary.image('ref_images', reference_images, max_outputs=batch_size)
tf.summary.image('query_images', query_images_reshaped, max_outputs=10)
predictions = tf.cast(
tf.argmax(logits, axis=-1), tf.uint8)[..., tf.newaxis]
# Scale up so that we can actually see something.
tf.summary.image('predictions', predictions * 32, max_outputs=10)
# We currently only show the first dimension of the features for background
# and the first foreground object.
tf.summary.image('nn_fg_features', nn_features[..., 0:1, 0],
max_outputs=batch_size)
if prev_frame_nn_features is not None:
tf.summary.image('nn_fg_features_prev', prev_frame_nn_features[..., 0:1, 0],
max_outputs=batch_size)
tf.summary.image('nn_bg_features', nn_features[..., 1:2, 0],
max_outputs=batch_size)
if prev_frame_nn_features is not None:
tf.summary.image('nn_bg_features_prev',
prev_frame_nn_features[..., 1:2, 0],
max_outputs=batch_size)
def get_embeddings(images, model_options, embedding_dimension):
"""Extracts embedding vectors for images. Should only be used for inference.
Args:
images: A tensor of shape [batch, height, width, channels].
model_options: A ModelOptions instance to configure models.
embedding_dimension: Integer, the dimension of the embedding.
Returns:
embeddings: A tensor of shape [batch, height, width, embedding_dimension].
"""
features, end_points = model.extract_features(
images,
model_options,
is_training=False)
if model_options.decoder_output_stride is not None:
decoder_output_stride = min(model_options.decoder_output_stride)
if model_options.crop_size is None:
height = tf.shape(images)[1]
width = tf.shape(images)[2]
else:
height, width = model_options.crop_size
features = model.refine_by_decoder(
features,
end_points,
crop_size=[height, width],
decoder_output_stride=[decoder_output_stride],
decoder_use_separable_conv=model_options.decoder_use_separable_conv,
model_variant=model_options.model_variant,
is_training=False)
with tf.variable_scope('embedding'):
embeddings = split_separable_conv2d_with_identity_initializer(
features, embedding_dimension, scope='split_separable_conv2d')
return embeddings
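# Inference-time usage sketch (assumes a configured ModelOptions instance and
# that variables are restored from a checkpoint elsewhere; the placeholder
# shape is an assumption):
#   images_ph = tf.placeholder(tf.float32, [1, 465, 465, 3])
#   embeddings = get_embeddings(images_ph, model_options, embedding_dimension=100)
#   # The spatial size of embeddings is determined by the decoder output stride.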
def get_logits_with_matching(images,
model_options,
weight_decay=0.0001,
reuse=None,
is_training=False,
fine_tune_batch_norm=False,
reference_labels=None,
batch_size=None,
num_frames_per_video=None,
embedding_dimension=None,
max_neighbors_per_object=0,
k_nearest_neighbors=1,
use_softmax_feedback=True,
initial_softmax_feedback=None,
embedding_seg_feature_dimension=256,
embedding_seg_n_layers=4,
embedding_seg_kernel_size=7,
embedding_seg_atrous_rates=None,
normalize_nearest_neighbor_distances=True,
also_attend_to_previous_frame=True,
damage_initial_previous_frame_mask=False,
use_local_previous_frame_attention=True,
previous_frame_attention_window_size=15,
use_first_frame_matching=True,
also_return_embeddings=False,
ref_embeddings=None):
"""Gets the logits by atrous/image spatial pyramid pooling using attention.
Args:
images: A tensor of size [batch, height, width, channels].
model_options: A ModelOptions instance to configure models.
weight_decay: The weight decay for model variables.
reuse: Reuse the model variables or not.
is_training: Is training or not.
fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
reference_labels: The segmentation labels of the reference frame on which
attention is applied.
batch_size: Integer, the number of videos in a batch.
num_frames_per_video: Integer, the number of frames per video.
embedding_dimension: Integer, the dimension of the embedding.
max_neighbors_per_object: Integer, the maximum number of candidates
for the nearest neighbor query per object after subsampling.
Can be 0 for no subsampling.
k_nearest_neighbors: Integer, the number of nearest neighbors to use.
use_softmax_feedback: Boolean, whether to give the softmax predictions of
the last frame as additional input to the segmentation head.
initial_softmax_feedback: List of Float32 tensors, or None. Can be used to
initialize the softmax predictions used for the feedback loop.
Only has an effect if use_softmax_feedback is True.
embedding_seg_feature_dimension: Integer, the dimensionality used in the
segmentation head layers.
embedding_seg_n_layers: Integer, the number of layers in the segmentation
head.
embedding_seg_kernel_size: Integer, the kernel size used in the
segmentation head.
embedding_seg_atrous_rates: List of integers of length
embedding_seg_n_layers, the atrous rates to use for the segmentation head.
normalize_nearest_neighbor_distances: Boolean, whether to normalize the
nearest neighbor distances to [0,1] using sigmoid, scale and shift.
also_attend_to_previous_frame: Boolean, whether to also use nearest
neighbor attention with respect to the previous frame.
damage_initial_previous_frame_mask: Boolean, whether to artificially damage
the initial previous frame mask. Only has an effect if
also_attend_to_previous_frame is True.
use_local_previous_frame_attention: Boolean, whether to restrict the
previous frame attention to a local search window.
Only has an effect if also_attend_to_previous_frame is True.
previous_frame_attention_window_size: Integer, the window size used for
local previous frame attention, if use_local_previous_frame_attention
is True.
use_first_frame_matching: Boolean, whether to extract features by matching
to the reference frame. This should always be true except for ablation
experiments.
also_return_embeddings: Boolean, whether to return the embeddings as well.
ref_embeddings: Tuple of
(first_frame_embeddings, previous_frame_embeddings),
each of shape [batch, height, width, embedding_dimension], or None.
Returns:
outputs_to_logits: A map from output_type to logits.
If also_return_embeddings is True, it will also return an embeddings
tensor of shape [batch, height, width, embedding_dimension].
"""
features, end_points = model.extract_features(
images,
model_options,
weight_decay=weight_decay,
reuse=reuse,
is_training=is_training,
fine_tune_batch_norm=fine_tune_batch_norm)
if model_options.decoder_output_stride:
decoder_output_stride = min(model_options.decoder_output_stride)
if model_options.crop_size is None:
height = tf.shape(images)[1]
width = tf.shape(images)[2]
else:
height, width = model_options.crop_size
decoder_height = model.scale_dimension(height, 1.0 / decoder_output_stride)
decoder_width = model.scale_dimension(width, 1.0 / decoder_output_stride)
features = model.refine_by_decoder(
features,
end_points,
crop_size=[height, width],
decoder_output_stride=[decoder_output_stride],
decoder_use_separable_conv=model_options.decoder_use_separable_conv,
model_variant=model_options.model_variant,
weight_decay=weight_decay,
reuse=reuse,
is_training=is_training,
fine_tune_batch_norm=fine_tune_batch_norm)
with tf.variable_scope('embedding', reuse=reuse):
embeddings = split_separable_conv2d_with_identity_initializer(
features, embedding_dimension, scope='split_separable_conv2d')
embeddings = tf.identity(embeddings, name='embeddings')
scaled_reference_labels = tf.image.resize_nearest_neighbor(
reference_labels,
resolve_shape(embeddings, 4)[1:3],
align_corners=True)
h, w = decoder_height, decoder_width
if num_frames_per_video is None:
num_frames_per_video = tf.size(embeddings) // (
batch_size * h * w * embedding_dimension)
new_labels_shape = tf.stack([batch_size, -1, h, w, 1])
reshaped_reference_labels = tf.reshape(scaled_reference_labels,
new_labels_shape)
new_embeddings_shape = tf.stack([batch_size,
num_frames_per_video, h, w,
embedding_dimension])
reshaped_embeddings = tf.reshape(embeddings, new_embeddings_shape)
all_nn_features = []
all_ref_obj_ids = []
# To keep things simple, we do all of this separately for each sequence for now.
for n in range(batch_size):
embedding = reshaped_embeddings[n]
if ref_embeddings is None:
n_chunks = 100
reference_embedding = embedding[0]
if also_attend_to_previous_frame or use_softmax_feedback:
queries_embedding = embedding[2:]
else:
queries_embedding = embedding[1:]
else:
if USE_CORRELATION_COST:
n_chunks = 20
else:
n_chunks = 500
reference_embedding = ref_embeddings[0][n]
queries_embedding = embedding
reference_labels = reshaped_reference_labels[n][0]
nn_features_n, ref_obj_ids = nearest_neighbor_features_per_object(
reference_embedding, queries_embedding, reference_labels,
max_neighbors_per_object, k_nearest_neighbors, n_chunks=n_chunks)
if normalize_nearest_neighbor_distances:
nn_features_n = (tf.nn.sigmoid(nn_features_n) - 0.5) * 2
all_nn_features.append(nn_features_n)
all_ref_obj_ids.append(ref_obj_ids)
feat_dim = resolve_shape(features)[-1]
features = tf.reshape(features, tf.stack(
[batch_size, num_frames_per_video, h, w, feat_dim]))
if ref_embeddings is None:
# Strip the features for the reference frame.
if also_attend_to_previous_frame or use_softmax_feedback:
features = features[:, 2:]
else:
features = features[:, 1:]
# To keep things simple, we do all of this separately for each sequence for now.
outputs_to_logits = {output: [] for
output in model_options.outputs_to_num_classes}
for n in range(batch_size):
features_n = features[n]
nn_features_n = all_nn_features[n]
nn_features_n_tr = tf.transpose(nn_features_n, [3, 0, 1, 2, 4])
n_objs = tf.shape(nn_features_n_tr)[0]
# Repeat features for every object.
features_n_tiled = tf.tile(features_n[tf.newaxis],
multiples=[n_objs, 1, 1, 1, 1])
prev_frame_labels = None
if also_attend_to_previous_frame:
prev_frame_labels = reshaped_reference_labels[n, 1]
if is_training and damage_initial_previous_frame_mask:
# Damage the previous frame masks.
prev_frame_labels = mask_damaging.damage_masks(prev_frame_labels,
dilate=False)
tf.summary.image('prev_frame_labels',
tf.cast(prev_frame_labels[tf.newaxis],
tf.uint8) * 32)
initial_softmax_feedback_n = create_initial_softmax_from_labels(
prev_frame_labels, reshaped_reference_labels[n][0],
decoder_output_stride=None, reduce_labels=True)
elif initial_softmax_feedback is not None:
initial_softmax_feedback_n = initial_softmax_feedback[n]
else:
initial_softmax_feedback_n = None
if initial_softmax_feedback_n is None:
last_softmax = tf.zeros((n_objs, h, w, 1), dtype=tf.float32)
else:
last_softmax = tf.transpose(initial_softmax_feedback_n, [2, 0, 1])[
..., tf.newaxis]
assert len(model_options.outputs_to_num_classes) == 1
output = list(model_options.outputs_to_num_classes.keys())[0]
logits = []
n_ref_frames = 1
prev_frame_nn_features_n = None
if also_attend_to_previous_frame or use_softmax_feedback:
n_ref_frames += 1
if ref_embeddings is not None:
n_ref_frames = 0
for t in range(num_frames_per_video - n_ref_frames):
to_concat = [features_n_tiled[:, t]]
if use_first_frame_matching:
to_concat.append(nn_features_n_tr[:, t])
if use_softmax_feedback:
to_concat.append(last_softmax)
if also_attend_to_previous_frame:
assert normalize_nearest_neighbor_distances, (
'previous frame attention currently only works when normalized '
'distances are used')
embedding = reshaped_embeddings[n]
if ref_embeddings is None:
last_frame_embedding = embedding[t + 1]
query_embeddings = embedding[t + 2, tf.newaxis]
else:
last_frame_embedding = ref_embeddings[1][0]
query_embeddings = embedding
if use_local_previous_frame_attention:
assert query_embeddings.shape[0] == 1
prev_frame_nn_features_n = (
local_previous_frame_nearest_neighbor_features_per_object(
last_frame_embedding,
query_embeddings[0],
prev_frame_labels,
all_ref_obj_ids[n],
max_distance=previous_frame_attention_window_size)
)
else:
prev_frame_nn_features_n, _ = (
nearest_neighbor_features_per_object(
last_frame_embedding, query_embeddings, prev_frame_labels,
max_neighbors_per_object, k_nearest_neighbors,
gt_ids=all_ref_obj_ids[n]))
prev_frame_nn_features_n = (tf.nn.sigmoid(
prev_frame_nn_features_n) - 0.5) * 2
prev_frame_nn_features_n_sq = tf.squeeze(prev_frame_nn_features_n,
axis=0)
prev_frame_nn_features_n_tr = tf.transpose(
prev_frame_nn_features_n_sq, [2, 0, 1, 3])
to_concat.append(prev_frame_nn_features_n_tr)
features_n_concat_t = tf.concat(to_concat, axis=-1)
embedding_seg_features_n_t = (
create_embedding_segmentation_features(
features_n_concat_t, embedding_seg_feature_dimension,
embedding_seg_n_layers, embedding_seg_kernel_size,
reuse or n > 0, atrous_rates=embedding_seg_atrous_rates))
logits_t = model.get_branch_logits(
embedding_seg_features_n_t,
1,
model_options.atrous_rates,
aspp_with_batch_norm=model_options.aspp_with_batch_norm,
kernel_size=model_options.logits_kernel_size,
weight_decay=weight_decay,
reuse=reuse or n > 0 or t > 0,
scope_suffix=output
)
logits.append(logits_t)
prev_frame_labels = tf.transpose(tf.argmax(logits_t, axis=0),
[2, 0, 1])
last_softmax = tf.nn.softmax(logits_t, axis=0)
logits = tf.stack(logits, axis=1)
logits_shape = tf.stack(
[n_objs, num_frames_per_video - n_ref_frames] +
resolve_shape(logits)[2:-1])
logits_reshaped = tf.reshape(logits, logits_shape)
logits_transposed = tf.transpose(logits_reshaped, [1, 2, 3, 0])
outputs_to_logits[output].append(logits_transposed)
add_image_summaries(
images[n * num_frames_per_video: (n+1) * num_frames_per_video],
nn_features_n,
logits_transposed,
batch_size=1,
prev_frame_nn_features=prev_frame_nn_features_n)
if also_return_embeddings:
return outputs_to_logits, embeddings
else:
return outputs_to_logits
def subsample_reference_embeddings_and_labels(
reference_embeddings_flat, reference_labels_flat, ref_obj_ids,
max_neighbors_per_object):
"""Subsamples the reference embedding vectors and labels.
After subsampling, at most max_neighbors_per_object items will remain per
class.
Args:
reference_embeddings_flat: Tensor of shape [n, embedding_dim],
the embedding vectors for the reference frame.
reference_labels_flat: Tensor of shape [n, 1],
the class labels of the reference frame.
ref_obj_ids: An int32 tensor of the unique object ids present
in the reference labels.
max_neighbors_per_object: Integer, the maximum number of candidates
for the nearest neighbor query per object after subsampling,
or 0 for no subsampling.
Returns:
reference_embeddings_flat: Tensor of shape [n_sub, embedding_dim],
the subsampled embedding vectors for the reference frame.
reference_labels_flat: Tensor of shape [n_sub, 1],
the class labels of the reference frame.
"""
if max_neighbors_per_object == 0:
return reference_embeddings_flat, reference_labels_flat
same_label_mask = tf.equal(reference_labels_flat[tf.newaxis, :],
ref_obj_ids[:, tf.newaxis])
max_neighbors_per_object_repeated = tf.tile(
tf.constant(max_neighbors_per_object)[tf.newaxis],
multiples=[tf.size(ref_obj_ids)])
# map_fn on GPU has sometimes caused problems, so we place it on the CPU
# for now.
with tf.device('cpu:0'):
subsampled_indices = tf.map_fn(_create_subsampling_mask,
(same_label_mask,
max_neighbors_per_object_repeated),
dtype=tf.int64,
name='subsample_labels_map_fn',
parallel_iterations=1)
mask = tf.not_equal(subsampled_indices, tf.constant(-1, dtype=tf.int64))
masked_indices = tf.boolean_mask(subsampled_indices, mask)
reference_embeddings_flat = tf.gather(reference_embeddings_flat,
masked_indices)
reference_labels_flat = tf.gather(reference_labels_flat, masked_indices)
return reference_embeddings_flat, reference_labels_flat
def _create_subsampling_mask(args):
"""Creates boolean mask which can be used to subsample the labels.
Args:
args: tuple of (label_mask, max_neighbors_per_object), where label_mask
is the mask to be subsampled and max_neighbors_per_object is a int scalar,
the maximum number of neighbors to be retained after subsampling.
Returns:
The boolean mask for subsampling the labels.
"""
label_mask, max_neighbors_per_object = args
indices = tf.squeeze(tf.where(label_mask), axis=1)
shuffled_indices = tf.random_shuffle(indices)
subsampled_indices = shuffled_indices[:max_neighbors_per_object]
n_pad = max_neighbors_per_object - tf.size(subsampled_indices)
padded_label = -1
padding = tf.fill((n_pad,), tf.constant(padded_label, dtype=tf.int64))
padded = tf.concat([subsampled_indices, padding], axis=0)
return padded
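# Toy sketch of the padding behaviour (inputs are made up for illustration):
#   padded = _create_subsampling_mask(
#       (tf.constant([False, True, False, True]), tf.constant(3)))
#   # label_mask selects indices [1, 3]; with max_neighbors_per_object=3 the
#   # result is e.g. [3, 1, -1] (order random, padded with -1). The caller
#   # removes the -1 entries again via tf.boolean_mask.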
def conv2d_identity_initializer(scale=1.0, mean=0, stddev=3e-2):
"""Creates an identity initializer for TensorFlow conv2d.
We add a small amount of normal noise to the initialization matrix.
Code copied from lcchen@.
Args:
scale: The scale coefficient for the identity weight matrix.
mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
truncated normal distribution.
stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
of the truncated normal distribution.
Returns:
An identity initializer function for TensorFlow conv2d.
"""
def _initializer(shape, dtype=tf.float32, partition_info=None):
"""Returns the identity matrix scaled by `scale`.
Args:
shape: A tuple of int32 numbers indicating the shape of the initializing
matrix.
dtype: The data type of the initializing matrix.
partition_info: (Optional) variable_scope._PartitionInfo object holding
additional information about how the variable is partitioned. This input
is not used in our case, but is required by TensorFlow.
Returns:
An identity matrix.
Raises:
ValueError: If len(shape) != 4, or shape[0] != shape[1], or shape[0] is
not odd, or shape[1] is not odd.
"""
del partition_info
if len(shape) != 4:
raise ValueError('Expect shape length to be 4.')
if shape[0] != shape[1]:
raise ValueError('Expect shape[0] = shape[1].')
if shape[0] % 2 != 1:
raise ValueError('Expect shape[0] to be odd value.')
if shape[1] % 2 != 1:
raise ValueError('Expect shape[1] to be odd value.')
weights = np.zeros(shape, dtype=np.float32)
center_y = shape[0] // 2
center_x = shape[1] // 2
min_channel = min(shape[2], shape[3])
for i in range(min_channel):
weights[center_y, center_x, i, i] = scale
return tf.constant(weights, dtype=dtype) + tf.truncated_normal(
shape, mean=mean, stddev=stddev, dtype=dtype)
return _initializer
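# Sketch of what the initializer produces (the shape is an assumption): for a
# 3x3 kernel with 2 input and 2 output channels,
#   init = conv2d_identity_initializer(scale=1.0)
#   w = init((3, 3, 2, 2))
# yields, up to the small truncated-normal noise, w[1, 1, i, i] ~= 1.0 for
# i in {0, 1} and ~0 elsewhere, i.e. the convolution initially acts as an
# identity on the first min(in_channels, out_channels) channels.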
def split_separable_conv2d_with_identity_initializer(
inputs,
filters,
kernel_size=3,
rate=1,
weight_decay=0.00004,
scope=None):
"""Splits a separable conv2d into depthwise and pointwise conv2d.
This operation differs from `tf.layers.separable_conv2d` as this operation
applies activation function between depthwise and pointwise conv2d.
Args:
inputs: Input tensor with shape [batch, height, width, channels].
filters: Number of filters in the 1x1 pointwise convolution.
kernel_size: A list of length 2: [kernel_height, kernel_width] of the
filters. Can be an int if both values are the same.
rate: Atrous convolution rate for the depthwise convolution.
weight_decay: The weight decay to use for regularizing the model.
scope: Optional scope for the operation.
Returns:
Computed features after split separable conv2d.
"""
initializer = conv2d_identity_initializer()
outputs = slim.separable_conv2d(
inputs,
None,
kernel_size=kernel_size,
depth_multiplier=1,
rate=rate,
weights_initializer=initializer,
weights_regularizer=None,
scope=scope + '_depthwise')
return slim.conv2d(
outputs,
filters,
1,
weights_initializer=initializer,
weights_regularizer=slim.l2_regularizer(weight_decay),
scope=scope + '_pointwise')
def create_initial_softmax_from_labels(last_frame_labels, reference_labels,
decoder_output_stride, reduce_labels):
"""Creates initial softmax predictions from last frame labels.
Args:
last_frame_labels: last frame labels of shape [1, height, width, 1].
reference_labels: reference frame labels of shape [1, height, width, 1].
decoder_output_stride: Integer, the stride of the decoder. Can be None, in
this case it's assumed that the last_frame_labels and reference_labels
are already scaled to the decoder output resolution.
reduce_labels: Boolean, whether to reduce the depth of the softmax one_hot
encoding to the actual number of labels present in the reference frame
(otherwise the depth will be the highest label index + 1).
Returns:
init_softmax: the initial softmax predictions.
"""
if decoder_output_stride is None:
labels_output_size = last_frame_labels
reference_labels_output_size = reference_labels
else:
h = tf.shape(last_frame_labels)[1]
w = tf.shape(last_frame_labels)[2]
h_sub = model.scale_dimension(h, 1.0 / decoder_output_stride)
w_sub = model.scale_dimension(w, 1.0 / decoder_output_stride)
labels_output_size = tf.image.resize_nearest_neighbor(
last_frame_labels, [h_sub, w_sub], align_corners=True)
reference_labels_output_size = tf.image.resize_nearest_neighbor(
reference_labels, [h_sub, w_sub], align_corners=True)
if reduce_labels:
unique_labels, _ = tf.unique(tf.reshape(reference_labels_output_size, [-1]))
depth = tf.size(unique_labels)
else:
depth = tf.reduce_max(reference_labels_output_size) + 1
one_hot_assertion = tf.assert_less(tf.reduce_max(labels_output_size), depth)
with tf.control_dependencies([one_hot_assertion]):
init_softmax = tf.one_hot(tf.squeeze(labels_output_size,
axis=-1),
depth=depth,
dtype=tf.float32)
return init_softmax
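# Minimal sketch (toy labels; assumptions only). With reduce_labels=True the
# one-hot depth equals the number of labels in the reference frame, so the
# last-frame labels must already be remapped into [0, depth); the assertion
# above guards this.
#   labels = tf.constant([[[[0], [1]]]])  # [1, 1, 2, 1]
#   init = create_initial_softmax_from_labels(
#       labels, labels, decoder_output_stride=None, reduce_labels=True)
#   # init: [1, 1, 2, 2], the one-hot encoding of the last-frame labels.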
def local_previous_frame_nearest_neighbor_features_per_object(
prev_frame_embedding, query_embedding, prev_frame_labels,
gt_ids, max_distance=9):
"""Computes nearest neighbor features while only allowing local matches.
Args:
prev_frame_embedding: Tensor of shape [height, width, embedding_dim],
the embedding vectors for the last frame.
query_embedding: Tensor of shape [height, width, embedding_dim],
the embedding vectors for the query frames.
prev_frame_labels: Tensor of shape [height, width, 1], the class labels of
the previous frame.
gt_ids: Int Tensor of shape [n_objs] of the sorted unique ground truth
ids in the first frame.
max_distance: Integer, the maximum distance allowed for local matching.
Returns:
nn_features: A float32 np.array of nearest neighbor features of shape
[1, height, width, n_objects, 1].
"""
with tf.name_scope(
'local_previous_frame_nearest_neighbor_features_per_object'):
if USE_CORRELATION_COST:
tf.logging.info('Using correlation_cost.')
d = local_pairwise_distances(query_embedding, prev_frame_embedding,
max_distance=max_distance)
else:
# Slow fallback in case correlation_cost is not available.
tf.logging.warn('correlation cost is not available, using slow fallback '
'implementation.')
d = local_pairwise_distances2(query_embedding, prev_frame_embedding,
max_distance=max_distance)
d = (tf.nn.sigmoid(d) - 0.5) * 2
height = tf.shape(prev_frame_embedding)[0]
width = tf.shape(prev_frame_embedding)[1]
# Create offset versions of the mask.
if USE_CORRELATION_COST:
# New, faster code with cross-correlation via correlation_cost.
# Due to padding we have to add 1 to the labels.
offset_labels = correlation_cost_op.correlation_cost(
tf.ones((1, height, width, 1)),
tf.cast(prev_frame_labels + 1, tf.float32)[tf.newaxis],
kernel_size=1,
max_displacement=max_distance, stride_1=1, stride_2=1,
pad=max_distance)
offset_labels = tf.squeeze(offset_labels, axis=0)[..., tf.newaxis]
# Subtract the 1 again and round.
offset_labels = tf.round(offset_labels - 1)
offset_masks = tf.equal(
offset_labels,
tf.cast(gt_ids, tf.float32)[tf.newaxis, tf.newaxis, tf.newaxis, :])
else:
# Slower code without a dependency on correlation_cost.
masks = tf.equal(prev_frame_labels, gt_ids[tf.newaxis, tf.newaxis])
padded_masks = tf.pad(masks,
[[max_distance, max_distance],
[max_distance, max_distance],
[0, 0]])
offset_masks = []
for y_start in range(2 * max_distance + 1):
y_end = y_start + height
masks_slice = padded_masks[y_start:y_end]
for x_start in range(2 * max_distance + 1):
x_end = x_start + width
offset_mask = masks_slice[:, x_start:x_end]
offset_masks.append(offset_mask)
offset_masks = tf.stack(offset_masks, axis=2)
pad = tf.ones((height, width, (2 * max_distance + 1) ** 2, tf.size(gt_ids)))
d_tiled = tf.tile(d[..., tf.newaxis], multiples=(1, 1, 1, tf.size(gt_ids)))
d_masked = tf.where(offset_masks, d_tiled, pad)
dists = tf.reduce_min(d_masked, axis=2)
dists = tf.reshape(dists, (1, height, width, tf.size(gt_ids), 1))
return dists
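# Usage sketch for local matching (toy shapes; all values are assumptions):
#   prev_emb = tf.random_normal([65, 65, 100])
#   query_emb = tf.random_normal([65, 65, 100])
#   prev_labels = tf.zeros([65, 65, 1], dtype=tf.int32)
#   d = local_previous_frame_nearest_neighbor_features_per_object(
#       prev_emb, query_emb, prev_labels, gt_ids=tf.constant([0]),
#       max_distance=9)
#   # d: [1, 65, 65, 1, 1]; entries are sigmoid-normalized distances in [0, 1],
#   # with 1.0 where no pixel of the object lies inside the search window.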
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for embedding utils."""
import unittest
import numpy as np
import tensorflow as tf
from feelvos.utils import embedding_utils
if embedding_utils.USE_CORRELATION_COST:
# pylint: disable=g-import-not-at-top
from correlation_cost.python.ops import correlation_cost_op
class EmbeddingUtilsTest(tf.test.TestCase):
def test_pairwise_distances(self):
x = np.arange(100, dtype=np.float32).reshape(20, 5)
y = np.arange(100, 200, dtype=np.float32).reshape(20, 5)
g = tf.Graph()
with g.as_default():
with self.test_session(graph=g) as sess:
x = tf.constant(x)
y = tf.constant(y)
d1 = embedding_utils.pairwise_distances(x, y)
d2 = embedding_utils.pairwise_distances2(x, y)
d1_val, d2_val = sess.run([d1, d2])
self.assertAllClose(d1_val, d2_val)
@unittest.skipIf(not embedding_utils.USE_CORRELATION_COST,
'depends on correlation_cost')
def test_correlation_cost_one_dimensional(self):
a = np.array([[[[1.0], [2.0]], [[3.0], [4.0]]]])
b = np.array([[[[2.0], [1.0]], [[4.0], [3.0]]]])
g = tf.Graph()
with g.as_default():
with self.test_session(graph=g) as sess:
c = correlation_cost_op.correlation_cost(
a, b, kernel_size=1, max_displacement=1, stride_1=1, stride_2=1,
pad=1)
c = tf.squeeze(c, axis=0)
c_val = sess.run(c)
self.assertAllEqual(c_val.shape, (2, 2, 9))
for y in range(2):
for x in range(2):
for dy in range(-1, 2):
for dx in range(-1, 2):
a_slice = a[0, y, x, 0]
if y + dy < 0 or y + dy > 1 or x + dx < 0 or x + dx > 1:
b_slice = 0
else:
b_slice = b[0, y + dy, x + dx, 0]
expected = a_slice * b_slice
dy0 = dy + 1
dx0 = dx + 1
self.assertAlmostEqual(c_val[y, x, 3 * dy0 + dx0], expected)
@unittest.skipIf(not embedding_utils.USE_CORRELATION_COST,
'depends on correlation_cost')
def test_correlation_cost_two_dimensional(self):
a = np.array([[[[1.0, -5.0], [7.0, 2.0]], [[1.0, 3.0], [3.0, 4.0]]]])
b = np.array([[[[2.0, 1.0], [0.0, -9.0]], [[4.0, 3.0], [3.0, 1.0]]]])
g = tf.Graph()
with g.as_default():
with self.test_session(graph=g) as sess:
c = correlation_cost_op.correlation_cost(
a, b, kernel_size=1, max_displacement=1, stride_1=1, stride_2=1,
pad=1)
c = tf.squeeze(c, axis=0)
c_val = sess.run(c)
self.assertAllEqual(c_val.shape, (2, 2, 9))
for y in range(2):
for x in range(2):
for dy in range(-1, 2):
for dx in range(-1, 2):
a_slice = a[0, y, x, :]
if y + dy < 0 or y + dy > 1 or x + dx < 0 or x + dx > 1:
b_slice = 0
else:
b_slice = b[0, y + dy, x + dx, :]
expected = (a_slice * b_slice).mean()
dy0 = dy + 1
dx0 = dx + 1
self.assertAlmostEqual(c_val[y, x, 3 * dy0 + dx0], expected)
@unittest.skipIf(not embedding_utils.USE_CORRELATION_COST,
'depends on correlation_cost')
def test_local_pairwise_distances_one_dimensional(self):
a = np.array([[[1.0], [2.0]], [[3.0], [4.0]]])
b = np.array([[[2.0], [1.0]], [[4.0], [3.0]]])
g = tf.Graph()
with g.as_default():
with self.test_session(graph=g) as sess:
a_tf = tf.constant(a, dtype=tf.float32)
b_tf = tf.constant(b, dtype=tf.float32)
d = embedding_utils.local_pairwise_distances(a_tf, b_tf,
max_distance=1)
d_val = sess.run(d)
for y in range(2):
for x in range(2):
for dy in range(-1, 2):
for dx in range(-1, 2):
a_slice = a[y, x, 0]
if y + dy < 0 or y + dy > 1 or x + dx < 0 or x + dx > 1:
expected = float('inf')
else:
b_slice = b[y + dy, x + dx, 0]
expected = (a_slice - b_slice) ** 2
dy0 = dy + 1
dx0 = dx + 1
self.assertAlmostEqual(d_val[y, x, 3 * dy0 + dx0], expected)
@unittest.skipIf(not embedding_utils.USE_CORRELATION_COST,
'depends on correlation_cost')
def test_local_pairwise_distances_shape(self):
a = np.zeros((4, 5, 2))
b = np.zeros((4, 5, 2))
g = tf.Graph()
with g.as_default():
with self.test_session(graph=g) as sess:
a_tf = tf.constant(a, dtype=tf.float32)
b_tf = tf.constant(b, dtype=tf.float32)
d = embedding_utils.local_pairwise_distances(a_tf, b_tf, max_distance=4)
d_val = sess.run(d)
self.assertAllEqual(d_val.shape, (4, 5, 81))
@unittest.skipIf(not embedding_utils.USE_CORRELATION_COST,
'depends on correlation_cost')
def test_local_pairwise_distances_two_dimensional(self):
a = np.array([[[1.0, -5.0], [7.0, 2.0]], [[1.0, 3.0], [3.0, 4.0]]])
b = np.array([[[2.0, 1.0], [0.0, -9.0]], [[4.0, 3.0], [3.0, 1.0]]])
g = tf.Graph()
with g.as_default():
with self.test_session(graph=g) as sess:
a_tf = tf.constant(a, dtype=tf.float32)
b_tf = tf.constant(b, dtype=tf.float32)
d = embedding_utils.local_pairwise_distances(a_tf, b_tf,
max_distance=1)
d_val = sess.run(d)
for y in range(2):
for x in range(2):
for dy in range(-1, 2):
for dx in range(-1, 2):
a_slice = a[y, x, :]
if y + dy < 0 or y + dy > 1 or x + dx < 0 or x + dx > 1:
expected = float('inf')
else:
b_slice = b[y + dy, x + dx, :]
expected = ((a_slice - b_slice) ** 2).sum()
dy0 = dy + 1
dx0 = dx + 1
self.assertAlmostEqual(d_val[y, x, 3 * dy0 + dx0], expected)
@unittest.skipIf(not embedding_utils.USE_CORRELATION_COST,
'depends on correlation_cost')
def test_local_previous_frame_nearest_neighbor_features_per_object(self):
prev_frame_embedding = np.array([[[1.0, -5.0], [7.0, 2.0]],
[[1.0, 3.0], [3.0, 4.0]]]) / 10
query_embedding = np.array([[[2.0, 1.0], [0.0, -9.0]],
[[4.0, 3.0], [3.0, 1.0]]]) / 10
prev_frame_labels = np.array([[[0], [1]], [[1], [0]]])
gt_ids = np.array([0, 1])
g = tf.Graph()
with g.as_default():
with self.test_session(graph=g) as sess:
prev_frame_embedding_tf = tf.constant(prev_frame_embedding,
dtype=tf.float32)
query_embedding_tf = tf.constant(query_embedding, dtype=tf.float32)
embu = embedding_utils
dists = (
embu.local_previous_frame_nearest_neighbor_features_per_object(
prev_frame_embedding_tf, query_embedding_tf,
prev_frame_labels, gt_ids, max_distance=1))
dists = tf.squeeze(dists, axis=4)
dists = tf.squeeze(dists, axis=0)
dists_val = sess.run(dists)
for obj_id in gt_ids:
for y in range(2):
for x in range(2):
curr_min = 1.0
for dy in range(-1, 2):
for dx in range(-1, 2):
# Attention: here we shift the prev frame embedding,
# not the query.
if y + dy < 0 or y + dy > 1 or x + dx < 0 or x + dx > 1:
continue
if prev_frame_labels[y + dy, x + dx, 0] != obj_id:
continue
prev_frame_slice = prev_frame_embedding[y + dy, x + dx, :]
query_frame_slice = query_embedding[y, x, :]
v_unnorm = ((prev_frame_slice - query_frame_slice) ** 2).sum()
v = ((1.0 / (1.0 + np.exp(-v_unnorm))) - 0.5) * 2
curr_min = min(curr_min, v)
expected = curr_min
self.assertAlmostEqual(dists_val[y, x, obj_id], expected,
places=5)
if __name__ == '__main__':
tf.test.main()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for evaluations."""
import numpy as np
import PIL
import tensorflow as tf
pascal_colormap = [
0, 0, 0,
0.5020, 0, 0,
0, 0.5020, 0,
0.5020, 0.5020, 0,
0, 0, 0.5020,
0.5020, 0, 0.5020,
0, 0.5020, 0.5020,
0.5020, 0.5020, 0.5020,
0.2510, 0, 0,
0.7529, 0, 0,
0.2510, 0.5020, 0,
0.7529, 0.5020, 0,
0.2510, 0, 0.5020,
0.7529, 0, 0.5020,
0.2510, 0.5020, 0.5020,
0.7529, 0.5020, 0.5020,
0, 0.2510, 0,
0.5020, 0.2510, 0,
0, 0.7529, 0,
0.5020, 0.7529, 0,
0, 0.2510, 0.5020,
0.5020, 0.2510, 0.5020,
0, 0.7529, 0.5020,
0.5020, 0.7529, 0.5020,
0.2510, 0.2510, 0]
def save_segmentation_with_colormap(filename, img):
"""Saves a segmentation with the pascal colormap as expected for DAVIS eval.
Args:
filename: Where to store the segmentation.
img: A numpy array of the segmentation to be saved.
"""
if img.shape[-1] == 1:
img = img[..., 0]
# Save with colormap.
colormap = (np.array(pascal_colormap) * 255).round().astype('uint8')
colormap_image = PIL.Image.new('P', (16, 16))
colormap_image.putpalette(colormap)
pil_image = PIL.Image.fromarray(img.astype('uint8'))
pil_image_with_colormap = pil_image.quantize(palette=colormap_image)
with tf.gfile.GFile(filename, 'w') as f:
pil_image_with_colormap.save(f)
def save_embeddings(filename, embeddings):
with tf.gfile.GFile(filename, 'w') as f:
np.save(f, embeddings)
def calculate_iou(pred_labels, ref_labels):
"""Calculates the intersection over union for binary segmentation.
Args:
pred_labels: predicted segmentation labels.
ref_labels: reference segmentation labels.
Returns:
The IoU between pred_labels and ref_labels
"""
if ref_labels.any():
i = np.logical_and(pred_labels, ref_labels).sum()
u = np.logical_or(pred_labels, ref_labels).sum()
return i.astype('float') / u
else:
if pred_labels.any():
return 0.0
else:
return 1.0
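# Tiny numeric sketch (made-up masks):
#   pred = np.array([True, True, False, False])
#   ref = np.array([True, False, True, False])
#   calculate_iou(pred, ref)  # intersection 1, union 3 -> 1/3
# An empty reference yields 1.0 only if the prediction is empty as well,
# otherwise 0.0.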
def calculate_multi_object_miou_tf(pred_labels, ref_labels):
"""Calculates the mIoU for a batch of predicted and reference labels.
Args:
pred_labels: Int32 tensor of shape [batch, height, width, 1].
ref_labels: Int32 tensor of shape [batch, height, width, 1].
Returns:
The mIoU between pred_labels and ref_labels as float32 scalar tensor.
"""
def calculate_multi_object_miou(pred_labels_, ref_labels_):
"""Calculates the mIoU for predicted and reference labels in numpy.
Args:
pred_labels_: int32 np.array of shape [batch, height, width, 1].
ref_labels_: int32 np.array of shape [batch, height, width, 1].
Returns:
The mIoU between pred_labels_ and ref_labels_.
"""
assert len(pred_labels_.shape) == 4
assert pred_labels_.shape[3] == 1
assert pred_labels_.shape == ref_labels_.shape
ious = []
for pred_label, ref_label in zip(pred_labels_, ref_labels_):
ids = np.setdiff1d(np.unique(ref_label), [0])
if ids.size == 0:
continue
for id_ in ids:
iou = calculate_iou(pred_label == id_, ref_label == id_)
ious.append(iou)
if ious:
return np.cast['float32'](np.mean(ious))
else:
return np.cast['float32'](1.0)
miou = tf.py_func(calculate_multi_object_miou, [pred_labels, ref_labels],
tf.float32, name='calculate_multi_object_miou')
miou.set_shape(())
return miou
def calculate_multi_object_ious(pred_labels, ref_labels, label_set):
"""Calculates the intersection over union for binary segmentation.
Args:
pred_labels: predicted segmentation labels.
ref_labels: reference segmentation labels.
label_set: int np.array of object ids.
Returns:
float np.array of IoUs between pred_labels and ref_labels
for each object in label_set.
"""
# Background should not be included as object label.
return np.array([calculate_iou(pred_labels == label, ref_labels == label)
for label in label_set if label != 0])
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for artificially damaging segmentation masks."""
import numpy as np
from scipy.ndimage import interpolation
from skimage import morphology
from skimage import transform
import tensorflow as tf
def damage_masks(labels, shift=True, scale=True, rotate=True, dilate=True):
"""Damages segmentation masks by random transformations.
Args:
labels: Int32 labels tensor of shape (height, width, 1).
shift: Boolean, whether to damage the masks by shifting.
scale: Boolean, whether to damage the masks by scaling.
rotate: Boolean, whether to damage the masks by rotation.
dilate: Boolean, whether to damage the masks by dilation.
Returns:
The damaged version of labels.
"""
def _damage_masks_np(labels_):
return damage_masks_np(labels_, shift, scale, rotate, dilate)
damaged_masks = tf.py_func(_damage_masks_np, [labels], tf.int32,
name='damage_masks')
damaged_masks.set_shape(labels.get_shape())
return damaged_masks
def damage_masks_np(labels, shift=True, scale=True, rotate=True, dilate=True):
"""Performs the actual mask damaging in numpy.
Args:
labels: Int32 numpy array of shape (height, width, 1).
shift: Boolean, whether to damage the masks by shifting.
scale: Boolean, whether to damage the masks by scaling.
rotate: Boolean, whether to damage the masks by rotation.
dilate: Boolean, whether to damage the masks by dilation.
Returns:
The damaged version of labels.
"""
unique_labels = np.unique(labels)
unique_labels = np.setdiff1d(unique_labels, [0])
# Shuffle to get random depth ordering when combining together.
np.random.shuffle(unique_labels)
damaged_labels = np.zeros_like(labels)
for l in unique_labels:
obj_mask = (labels == l)
damaged_obj_mask = _damage_single_object_mask(obj_mask, shift, scale,
rotate, dilate)
damaged_labels[damaged_obj_mask] = l
return damaged_labels
def _damage_single_object_mask(mask, shift, scale, rotate, dilate):
"""Performs mask damaging in numpy for a single object.
Args:
mask: Boolean numpy array of shape(height, width, 1).
shift: Boolean, whether to damage the masks by shifting.
scale: Boolean, whether to damage the masks by scaling.
rotate: Boolean, whether to damage the masks by rotation.
dilate: Boolean, whether to damage the masks by dilation.
Returns:
The damaged version of mask.
"""
# For now we just do shifting and scaling. Better would be affine or thin
# plate spline transformations.
if shift:
mask = _shift_mask(mask)
if scale:
mask = _scale_mask(mask)
if rotate:
mask = _rotate_mask(mask)
if dilate:
mask = _dilate_mask(mask)
return mask
def _shift_mask(mask, max_shift_factor=0.05):
"""Damages a mask for a single object by randomly shifting it in numpy.
Args:
mask: Boolean numpy array of shape(height, width, 1).
max_shift_factor: Float scalar, the maximum factor for random shifting.
Returns:
The shifted version of mask.
"""
nzy, nzx, _ = mask.nonzero()
h = nzy.max() - nzy.min()
w = nzx.max() - nzx.min()
size = np.sqrt(h * w)
offset = np.random.uniform(-size * max_shift_factor, size * max_shift_factor,
2)
shifted_mask = interpolation.shift(np.squeeze(mask, axis=2),
offset, order=0).astype('bool')[...,
np.newaxis]
return shifted_mask
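# Worked example for the shift magnitude (numbers are illustrative): for an
# object whose bounding box is 100 x 64 pixels, size = sqrt(100 * 64) = 80,
# so with max_shift_factor=0.05 each component of the offset is drawn
# uniformly from [-4, 4] pixels.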
def _scale_mask(mask, scale_amount=0.025):
"""Damages a mask for a single object by randomly scaling it in numpy.
Args:
mask: Boolean numpy array of shape(height, width, 1).
scale_amount: Float scalar, the maximum factor for random scaling.
Returns:
The scaled version of mask.
"""
nzy, nzx, _ = mask.nonzero()
cy = 0.5 * (nzy.max() - nzy.min())
cx = 0.5 * (nzx.max() - nzx.min())
scale_factor = np.random.uniform(1.0 - scale_amount, 1.0 + scale_amount)
shift = transform.SimilarityTransform(translation=[-cx, -cy])
inv_shift = transform.SimilarityTransform(translation=[cx, cy])
s = transform.SimilarityTransform(scale=[scale_factor, scale_factor])
m = (shift + (s + inv_shift)).inverse
scaled_mask = transform.warp(mask, m) > 0.5
return scaled_mask
def _rotate_mask(mask, max_rot_degrees=3.0):
"""Damages a mask for a single object by randomly rotating it in numpy.
Args:
mask: Boolean numpy array of shape(height, width, 1).
max_rot_degrees: Float scalar, the maximum number of degrees to rotate.
Returns:
The rotated version of mask.
"""
cy = 0.5 * mask.shape[0]
cx = 0.5 * mask.shape[1]
rot_degrees = np.random.uniform(-max_rot_degrees, max_rot_degrees)
shift = transform.SimilarityTransform(translation=[-cx, -cy])
inv_shift = transform.SimilarityTransform(translation=[cx, cy])
r = transform.SimilarityTransform(rotation=np.deg2rad(rot_degrees))
m = (shift + (r + inv_shift)).inverse
scaled_mask = transform.warp(mask, m) > 0.5
return scaled_mask
def _dilate_mask(mask, dilation_radius=5):
"""Damages a mask for a single object by dilating it in numpy.
Args:
mask: Boolean numpy array of shape(height, width, 1).
dilation_radius: Integer, the radius of the used disk structure element.
Returns:
The dilated version of mask.
"""
disk = morphology.disk(dilation_radius, dtype=bool)
dilated_mask = morphology.binary_dilation(
np.squeeze(mask, axis=2), selem=disk)[..., np.newaxis]
return dilated_mask
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for training."""
import collections
import six
import tensorflow as tf
from deeplab.core import preprocess_utils
from deeplab.utils import train_utils
from feelvos.utils import embedding_utils
from feelvos.utils import eval_utils
slim = tf.contrib.slim
add_softmax_cross_entropy_loss_for_each_scale = (
train_utils.add_softmax_cross_entropy_loss_for_each_scale)
get_model_gradient_multipliers = train_utils.get_model_gradient_multipliers
get_model_learning_rate = train_utils.get_model_learning_rate
resolve_shape = preprocess_utils.resolve_shape
def add_triplet_loss_for_each_scale(batch_size, num_frames_per_video,
embedding_dim, scales_to_embeddings,
labels, scope):
"""Adds triplet loss for logits of each scale.
Args:
batch_size: Int, the number of video chunks sampled per batch.
num_frames_per_video: Int, the number of frames per video.
embedding_dim: Int, the dimension of the learned embedding.
scales_to_embeddings: A map from embedding names for different scales to
embeddings. The embeddings have shape [batch, embeddings_height,
embeddings_width, embedding_dim].
labels: Groundtruth labels with shape [batch, image_height, image_width, 1].
scope: String, the scope for the loss.
Raises:
ValueError: labels is None.
"""
if labels is None:
raise ValueError('No label for triplet loss.')
for scale, embeddings in six.iteritems(scales_to_embeddings):
loss_scope = None
if scope:
loss_scope = '%s_%s' % (scope, scale)
# Label is downsampled to the same size as logits.
scaled_labels = tf.image.resize_nearest_neighbor(
labels,
resolve_shape(embeddings, 4)[1:3],
align_corners=True)
# Reshape from [batch * num_frames, ...] to [batch, num_frames, ...].
h = tf.shape(embeddings)[1]
w = tf.shape(embeddings)[2]
new_labels_shape = tf.stack([batch_size, num_frames_per_video, h, w, 1])
reshaped_labels = tf.reshape(scaled_labels, new_labels_shape)
new_embeddings_shape = tf.stack([batch_size, num_frames_per_video, h, w,
-1])
reshaped_embeddings = tf.reshape(embeddings, new_embeddings_shape)
with tf.name_scope(loss_scope):
total_loss = tf.constant(0, dtype=tf.float32)
for n in range(batch_size):
embedding = reshaped_embeddings[n]
label = reshaped_labels[n]
n_pixels = h * w
n_anchors_used = 256
sampled_anchor_indices = tf.random_shuffle(tf.range(n_pixels))[
:n_anchors_used]
anchors_pool = tf.reshape(embedding[0], [-1, embedding_dim])
anchors_pool_classes = tf.reshape(label[0], [-1])
anchors = tf.gather(anchors_pool, sampled_anchor_indices)
anchor_classes = tf.gather(anchors_pool_classes, sampled_anchor_indices)
pos_neg_pool = tf.reshape(embedding[1:], [-1, embedding_dim])
pos_neg_pool_classes = tf.reshape(label[1:], [-1])
dists = embedding_utils.pairwise_distances(anchors, pos_neg_pool)
pos_mask = tf.equal(anchor_classes[:, tf.newaxis],
pos_neg_pool_classes[tf.newaxis, :])
neg_mask = tf.logical_not(pos_mask)
pos_mask_f = tf.cast(pos_mask, tf.float32)
neg_mask_f = tf.cast(neg_mask, tf.float32)
pos_dists = pos_mask_f * dists + 1e20 * neg_mask_f
neg_dists = neg_mask_f * dists + 1e20 * pos_mask_f
pos_dists_min = tf.reduce_min(pos_dists, axis=1)
neg_dists_min = tf.reduce_min(neg_dists, axis=1)
margin = 1.0
loss = tf.nn.relu(pos_dists_min - neg_dists_min + margin)
# Handle case that no positive is present (per anchor).
any_pos = tf.reduce_any(pos_mask, axis=1)
loss *= tf.cast(any_pos, tf.float32)
# Average over anchors
loss = tf.reduce_mean(loss, axis=0)
total_loss += loss
total_loss /= batch_size
# Scale the loss up a bit.
total_loss *= 3.0
tf.add_to_collection(tf.GraphKeys.LOSSES, total_loss)
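# The per-anchor objective above is a standard margin-based triplet loss on
# the embedding space:
#   loss(a) = max(0, min_p d(a, p) - min_n d(a, n) + margin)
# where p ranges over pixels of the same class and n over pixels of other
# classes in the later frames; anchors without any positive are masked out.
# Quick numeric check (made-up distances): min_p = 0.3, min_n = 0.9,
# margin = 1.0 gives loss = max(0, 0.3 - 0.9 + 1.0) = 0.4.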
def add_dynamic_softmax_cross_entropy_loss_for_each_scale(
scales_to_logits, labels, ignore_label, loss_weight=1.0,
upsample_logits=True, scope=None, top_k_percent_pixels=1.0,
hard_example_mining_step=100000):
"""Adds softmax cross entropy loss per scale for logits with varying classes.
Also adds summaries for mIoU.
Args:
scales_to_logits: A map from logits names for different scales to logits.
The logits are a list of length batch_size of tensors of shape
[time, logits_height, logits_width, num_classes].
labels: Groundtruth labels with shape [batch_size * time, image_height,
image_width, 1].
ignore_label: Integer, label to ignore.
loss_weight: Float, loss weight.
upsample_logits: Boolean, upsample logits or not.
scope: String, the scope for the loss.
top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its
value < 1.0, only compute the loss for the top k percent pixels (e.g.,
the top 20% pixels). This is useful for hard pixel mining.
hard_example_mining_step: An integer, the training step in which the
hard example mining kicks off. Note that we gradually reduce the
mining percent to the top_k_percent_pixels. For example, if
hard_example_mining_step=100K and top_k_percent_pixels=0.25, then
mining percent will gradually reduce from 100% to 25% until 100K steps
after which we only mine top 25% pixels.
Raises:
ValueError: Label or logits is None.
"""
if labels is None:
raise ValueError('No label for softmax cross entropy loss.')
if top_k_percent_pixels < 0 or top_k_percent_pixels > 1:
raise ValueError('Unexpected value of top_k_percent_pixels.')
for scale, logits in six.iteritems(scales_to_logits):
loss_scope = None
if scope:
loss_scope = '%s_%s' % (scope, scale)
if upsample_logits:
# Label is not downsampled, and instead we upsample logits.
assert isinstance(logits, collections.Sequence)
logits = [tf.image.resize_bilinear(
x,
preprocess_utils.resolve_shape(labels, 4)[1:3],
align_corners=True) for x in logits]
scaled_labels = labels
else:
# Label is downsampled to the same size as logits.
assert isinstance(logits, collections.Sequence)
scaled_labels = tf.image.resize_nearest_neighbor(
labels,
preprocess_utils.resolve_shape(logits[0], 4)[1:3],
align_corners=True)
batch_size = len(logits)
num_time = preprocess_utils.resolve_shape(logits[0])[0]
reshaped_labels = tf.reshape(
scaled_labels, ([batch_size, num_time] +
preprocess_utils.resolve_shape(scaled_labels)[1:]))
for n, logits_n in enumerate(logits):
labels_n = reshaped_labels[n]
labels_n = tf.reshape(labels_n, shape=[-1])
not_ignore_mask = tf.to_float(tf.not_equal(labels_n,
ignore_label)) * loss_weight
num_classes_n = tf.shape(logits_n)[-1]
one_hot_labels = slim.one_hot_encoding(
labels_n, num_classes_n, on_value=1.0, off_value=0.0)
logits_n_flat = tf.reshape(logits_n, shape=[-1, num_classes_n])
if top_k_percent_pixels == 1.0:
tf.losses.softmax_cross_entropy(
one_hot_labels,
logits_n_flat,
weights=not_ignore_mask,
scope=loss_scope)
else:
# Only compute the loss for top k percent pixels.
# First, compute the loss for all pixels. Note we do not put the loss
# to loss_collection and set reduction = None to keep the shape.
num_pixels = tf.to_float(tf.shape(logits_n_flat)[0])
pixel_losses = tf.losses.softmax_cross_entropy(
one_hot_labels,
logits_n_flat,
weights=not_ignore_mask,
scope='pixel_losses',
loss_collection=None,
reduction=tf.losses.Reduction.NONE)
# Compute the top_k_percent pixels based on current training step.
if hard_example_mining_step == 0:
# Directly focus on the top_k pixels.
top_k_pixels = tf.to_int32(top_k_percent_pixels * num_pixels)
else:
# Gradually reduce the mining percent to top_k_percent_pixels.
global_step = tf.to_float(tf.train.get_or_create_global_step())
ratio = tf.minimum(1.0, global_step / hard_example_mining_step)
top_k_pixels = tf.to_int32(
(ratio * top_k_percent_pixels + (1.0 - ratio)) * num_pixels)
_, top_k_indices = tf.nn.top_k(pixel_losses,
k=top_k_pixels,
sorted=True,
name='top_k_percent_pixels')
# Compute the loss for the top k percent pixels.
tf.losses.softmax_cross_entropy(
tf.gather(one_hot_labels, top_k_indices),
tf.gather(logits_n_flat, top_k_indices),
weights=tf.gather(not_ignore_mask, top_k_indices),
scope=loss_scope)
pred_n = tf.argmax(logits_n, axis=-1, output_type=tf.int32)[
..., tf.newaxis]
labels_n = labels[n * num_time: (n + 1) * num_time]
miou = eval_utils.calculate_multi_object_miou_tf(pred_n, labels_n)
tf.summary.scalar('miou', miou)
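# Worked example for the hard-example-mining schedule above (numbers are
# illustrative): with top_k_percent_pixels=0.25 and
# hard_example_mining_step=100000, at global_step=50000 the ratio is 0.5 and
# the mined fraction is 0.5 * 0.25 + 0.5 = 0.625, i.e. 62.5% of the pixels;
# it reaches 25% at step 100000 and stays there.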
def get_model_init_fn(train_logdir,
tf_initial_checkpoint,
initialize_last_layer,
last_layers,
ignore_missing_vars=False):
"""Gets the function initializing model variables from a checkpoint.
Args:
train_logdir: Log directory for training.
tf_initial_checkpoint: TensorFlow checkpoint for initialization.
initialize_last_layer: Initialize last layer or not.
last_layers: Last layers of the model.
ignore_missing_vars: Ignore missing variables in the checkpoint.
Returns:
Initialization function.
"""
if tf_initial_checkpoint is None:
tf.logging.info('Not initializing the model from a checkpoint.')
return None
if tf.train.latest_checkpoint(train_logdir):
tf.logging.info('Ignoring initialization; other checkpoint exists')
return None
tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)
# Variables that will not be restored.
exclude_list = ['global_step']
if not initialize_last_layer:
exclude_list.extend(last_layers)
variables_to_restore = slim.get_variables_to_restore(exclude=exclude_list)
if variables_to_restore:
return slim.assign_from_checkpoint_fn(
tf_initial_checkpoint,
variables_to_restore,
ignore_missing_vars=ignore_missing_vars)
return None
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for providing semantic segmentation video data."""
import tensorflow as tf
from feelvos import input_preprocess
from feelvos import model
from feelvos.utils import mask_damaging
from feelvos.utils import train_utils
slim = tf.contrib.slim
dataset_data_provider = slim.dataset_data_provider
MIN_LABEL_COUNT = 10
def decode_image_sequence(tensor, image_format='jpeg', shape=None,
channels=3, raw_dtype=tf.uint8):
"""Decodes a sequence of images.
Args:
tensor: the tensor of strings to decode, shape: [num_images]
image_format: a string (possibly tensor) with the format of the image.
Options include 'jpeg', 'png', and 'raw'.
shape: a list or tensor of the decoded image shape for a single image.
channels: if 'shape' is None, the third dimension of the image is set to
this value.
raw_dtype: if the image is encoded as raw bytes, this is the method of
decoding the bytes into values.
Returns:
The decoded images with shape [time, height, width, channels].
"""
handler = slim.tfexample_decoder.Image(
shape=shape, channels=channels, dtype=raw_dtype, repeated=True)
return handler.tensors_to_item({'image/encoded': tensor,
'image/format': image_format})
def _get_data(data_provider, dataset_split, video_frames_are_decoded):
"""Gets data from data provider.
Args:
data_provider: An object of slim.data_provider.
dataset_split: Dataset split.
video_frames_are_decoded: Boolean, whether the video frames are already
decoded.
Returns:
image: Image Tensor.
label: Label Tensor storing segmentation annotations.
object_label: An integer referring to the object_label according to the
labelmap. If the example has more than one object_label, the first one is
taken.
image_name: Image name.
height: Image height.
width: Image width.
video_id: String tensor representing the name of the video.
Raises:
ValueError: Failed to find label.
"""
if video_frames_are_decoded:
image, = data_provider.get(['image'])
else:
image, = data_provider.get(['image/encoded'])
# Some datasets do not contain image_name.
if 'image_name' in data_provider.list_items():
image_name, = data_provider.get(['image_name'])
else:
image_name = tf.constant('')
height, width = data_provider.get(['height', 'width'])
label = None
if dataset_split != 'test':
if video_frames_are_decoded:
if 'labels_class' not in data_provider.list_items():
raise ValueError('Failed to find labels.')
label, = data_provider.get(['labels_class'])
else:
key = 'segmentation/object/encoded'
if key not in data_provider.list_items():
raise ValueError('Failed to find labels.')
label, = data_provider.get([key])
object_label = None
video_id, = data_provider.get(['video_id'])
return image, label, object_label, image_name, height, width, video_id
def _has_foreground_and_background_in_first_frame(label, subsampling_factor):
"""Checks if the labels have foreground and background in the first frame.
Args:
label: Label tensor of shape [num_frames, height, width, 1].
subsampling_factor: Integer, the subsampling factor.
Returns:
Boolean, whether the labels have foreground and background in the first
frame.
"""
h, w = train_utils.resolve_shape(label)[1:3]
label_downscaled = tf.squeeze(
tf.image.resize_nearest_neighbor(label[0, tf.newaxis],
[h // subsampling_factor,
w // subsampling_factor],
align_corners=True),
axis=0)
is_bg = tf.equal(label_downscaled, 0)
is_fg = tf.logical_not(is_bg)
  # Just using reduce_any was not robust enough, so let's make sure the count
  # is above MIN_LABEL_COUNT.
  fg_count = tf.reduce_sum(tf.cast(is_fg, tf.int32))
  bg_count = tf.reduce_sum(tf.cast(is_bg, tf.int32))
  has_fg = tf.greater_equal(fg_count, MIN_LABEL_COUNT)
  has_bg = tf.greater_equal(bg_count, MIN_LABEL_COUNT)
return tf.logical_and(has_bg, has_fg)
def _has_foreground_and_background_in_first_frame_2(label,
decoder_output_stride):
"""Checks if the labels have foreground and background in the first frame.
Second attempt, this time we use the actual output dimension for resizing.
Args:
label: Label tensor of shape [num_frames, height, width, 1].
decoder_output_stride: Integer, the stride of the decoder output.
Returns:
Boolean, whether the labels have foreground and background in the first
frame.
"""
h, w = train_utils.resolve_shape(label)[1:3]
h_sub = model.scale_dimension(h, 1.0 / decoder_output_stride)
w_sub = model.scale_dimension(w, 1.0 / decoder_output_stride)
label_downscaled = tf.squeeze(
tf.image.resize_nearest_neighbor(label[0, tf.newaxis], [h_sub, w_sub],
align_corners=True), axis=0)
is_bg = tf.equal(label_downscaled, 0)
is_fg = tf.logical_not(is_bg)
  # Just using reduce_any was not robust enough, so let's make sure the count
  # is above MIN_LABEL_COUNT.
  fg_count = tf.reduce_sum(tf.cast(is_fg, tf.int32))
  bg_count = tf.reduce_sum(tf.cast(is_bg, tf.int32))
  has_fg = tf.greater_equal(fg_count, MIN_LABEL_COUNT)
  has_bg = tf.greater_equal(bg_count, MIN_LABEL_COUNT)
return tf.logical_and(has_bg, has_fg)
def _has_enough_pixels_of_each_object_in_first_frame(
label, decoder_output_stride):
"""Checks if for each object (incl. background) enough pixels are visible.
During test time, we will usually not see a reference frame in which only
very few pixels of one object are visible. These cases can be problematic
during training, especially if more than the 1-nearest neighbor is used.
That's why this function can be used to detect and filter these cases.
Args:
label: Label tensor of shape [num_frames, height, width, 1].
decoder_output_stride: Integer, the stride of the decoder output.
Returns:
Boolean, whether the labels have enough pixels of each object in the first
frame.
"""
h, w = train_utils.resolve_shape(label)[1:3]
h_sub = model.scale_dimension(h, 1.0 / decoder_output_stride)
w_sub = model.scale_dimension(w, 1.0 / decoder_output_stride)
label_downscaled = tf.squeeze(
tf.image.resize_nearest_neighbor(label[0, tf.newaxis], [h_sub, w_sub],
align_corners=True), axis=0)
_, _, counts = tf.unique_with_counts(
tf.reshape(label_downscaled, [-1]))
has_enough_pixels_per_object = tf.reduce_all(
tf.greater_equal(counts, MIN_LABEL_COUNT))
return has_enough_pixels_per_object
def get(dataset,
num_frames_per_video,
crop_size,
batch_size,
min_resize_value=None,
max_resize_value=None,
resize_factor=None,
min_scale_factor=1.,
max_scale_factor=1.,
scale_factor_step_size=0,
preprocess_image_and_label=True,
num_readers=1,
num_threads=1,
dataset_split=None,
is_training=True,
model_variant=None,
batch_capacity_factor=32,
video_frames_are_decoded=False,
decoder_output_stride=None,
first_frame_finetuning=False,
sample_only_first_frame_for_finetuning=False,
sample_adjacent_and_consistent_query_frames=False,
remap_labels_to_reference_frame=True,
generate_prev_frame_mask_by_mask_damaging=False,
three_frame_dataset=False,
add_prev_frame_label=True):
"""Gets the dataset split for semantic segmentation.
  This function gets the dataset split for semantic segmentation. In
  particular, it is a wrapper of (1) dataset_data_provider, which returns the
  raw dataset split, (2) input_preprocess, which preprocesses the raw data, and
  (3) the TensorFlow operation of batching the preprocessed data. The output
  can then be directly used for training, evaluation, or visualization.
Args:
dataset: An instance of slim Dataset.
num_frames_per_video: The number of frames used per video
crop_size: Image crop size [height, width].
batch_size: Batch size.
min_resize_value: Desired size of the smaller image side.
max_resize_value: Maximum allowed size of the larger image side.
resize_factor: Resized dimensions are multiple of factor plus one.
min_scale_factor: Minimum scale factor value.
max_scale_factor: Maximum scale factor value.
scale_factor_step_size: The step size from min scale factor to max scale
factor. The input is randomly scaled based on the value of
(min_scale_factor, max_scale_factor, scale_factor_step_size).
    preprocess_image_and_label: Boolean specifying whether preprocessing of the
      image and label will be performed.
num_readers: Number of readers for data provider.
num_threads: Number of threads for batching data.
dataset_split: Dataset split.
is_training: Is training or not.
model_variant: Model variant (string) for choosing how to mean-subtract the
images. See feature_extractor.network_map for supported model variants.
batch_capacity_factor: Batch capacity factor affecting the training queue
batch capacity.
video_frames_are_decoded: Boolean, whether the video frames are already
decoded
decoder_output_stride: Integer, the stride of the decoder output.
first_frame_finetuning: Boolean, whether to only sample the first frame
for fine-tuning.
sample_only_first_frame_for_finetuning: Boolean, whether to only sample the
first frame during fine-tuning. This should be False when using lucid or
      wonderland data, but True when fine-tuning on the first frame only.
Only has an effect if first_frame_finetuning is True.
sample_adjacent_and_consistent_query_frames: Boolean, if true, the query
frames (all but the first frame which is the reference frame) will be
sampled such that they are adjacent video frames and have the same
crop coordinates and flip augmentation.
remap_labels_to_reference_frame: Boolean, whether to remap the labels of
the query frames to match the labels of the (downscaled) reference frame.
If a query frame contains a label which is not present in the reference,
it will be mapped to background.
generate_prev_frame_mask_by_mask_damaging: Boolean, whether to generate
the masks used as guidance from the previous frame by damaging the
ground truth mask.
three_frame_dataset: Boolean, whether the dataset has exactly three frames
per video of which the first is to be used as reference and the two
others are consecutive frames to be used as query frames.
add_prev_frame_label: Boolean, whether to sample one more frame before the
first query frame to obtain a previous frame label. Only has an effect,
if sample_adjacent_and_consistent_query_frames is True and
generate_prev_frame_mask_by_mask_damaging is False.
Returns:
A dictionary of batched Tensors for semantic segmentation.
Raises:
ValueError: dataset_split is None, or Failed to find labels.
"""
if dataset_split is None:
raise ValueError('Unknown dataset split.')
if model_variant is None:
tf.logging.warning('Please specify a model_variant. See '
'feature_extractor.network_map for supported model '
'variants.')
data_provider = dataset_data_provider.DatasetDataProvider(
dataset,
num_readers=num_readers,
num_epochs=None if is_training else 1,
shuffle=is_training)
image, label, object_label, image_name, height, width, video_id = _get_data(
data_provider, dataset_split, video_frames_are_decoded)
sampling_is_valid = tf.constant(True)
if num_frames_per_video is not None:
total_num_frames = tf.shape(image)[0]
if first_frame_finetuning or three_frame_dataset:
if sample_only_first_frame_for_finetuning:
assert not sample_adjacent_and_consistent_query_frames, (
'this option does not make sense for sampling only first frame.')
# Sample the first frame num_frames_per_video times.
sel_indices = tf.tile(tf.constant(0, dtype=tf.int32)[tf.newaxis],
multiples=[num_frames_per_video])
else:
if sample_adjacent_and_consistent_query_frames:
if add_prev_frame_label:
num_frames_per_video += 1
# Since this is first frame fine-tuning, we'll for now assume that
# each sequence has exactly 3 images: the ref frame and 2 adjacent
# query frames.
assert num_frames_per_video == 3
with tf.control_dependencies([tf.assert_equal(total_num_frames, 3)]):
sel_indices = tf.constant([1, 2], dtype=tf.int32)
else:
# Sample num_frames_per_video - 1 query frames which are not the
# first frame.
sel_indices = tf.random_shuffle(
tf.range(1, total_num_frames))[:(num_frames_per_video - 1)]
# Concat first frame as reference frame to the front.
sel_indices = tf.concat([tf.constant(0, dtype=tf.int32)[tf.newaxis],
sel_indices], axis=0)
else:
if sample_adjacent_and_consistent_query_frames:
if add_prev_frame_label:
# Sample one more frame which we can use to provide initial softmax
# feedback.
num_frames_per_video += 1
ref_idx = tf.random_shuffle(tf.range(total_num_frames))[0]
sampling_is_valid = tf.greater_equal(total_num_frames,
num_frames_per_video)
def sample_query_start_idx():
return tf.random_shuffle(
tf.range(total_num_frames - num_frames_per_video + 1))[0]
query_start_idx = tf.cond(sampling_is_valid, sample_query_start_idx,
lambda: tf.constant(0, dtype=tf.int32))
def sample_sel_indices():
return tf.concat(
[ref_idx[tf.newaxis],
tf.range(
query_start_idx,
query_start_idx + (num_frames_per_video - 1))], axis=0)
sel_indices = tf.cond(
sampling_is_valid, sample_sel_indices,
lambda: tf.zeros((num_frames_per_video,), dtype=tf.int32))
else:
# Randomly sample some frames from the video.
sel_indices = tf.random_shuffle(
tf.range(total_num_frames))[:num_frames_per_video]
image = tf.gather(image, sel_indices, axis=0)
if not video_frames_are_decoded:
image = decode_image_sequence(image)
if label is not None:
if num_frames_per_video is not None:
label = tf.gather(label, sel_indices, axis=0)
if not video_frames_are_decoded:
label = decode_image_sequence(label, image_format='png', channels=1)
# Sometimes, label is saved as [num_frames_per_video, height, width] or
# [num_frames_per_video, height, width, 1]. We change it to be
# [num_frames_per_video, height, width, 1].
if label.shape.ndims == 3:
label = tf.expand_dims(label, 3)
elif label.shape.ndims == 4 and label.shape.dims[3] == 1:
pass
else:
raise ValueError('Input label shape must be '
                         '[num_frames_per_video, height, width] or '
                         '[num_frames_per_video, height, width, 1]. '
'Got {}'.format(label.shape.ndims))
label.set_shape([None, None, None, 1])
# Add size of first dimension since tf can't figure it out automatically.
image.set_shape((num_frames_per_video, None, None, None))
if label is not None:
label.set_shape((num_frames_per_video, None, None, None))
preceding_frame_label = None
if preprocess_image_and_label:
if num_frames_per_video is None:
      raise ValueError('num_frames_per_video must be specified for preproc.')
original_images = []
images = []
labels = []
if sample_adjacent_and_consistent_query_frames:
num_frames_individual_preproc = 1
else:
num_frames_individual_preproc = num_frames_per_video
for frame_idx in range(num_frames_individual_preproc):
original_image_t, image_t, label_t = (
input_preprocess.preprocess_image_and_label(
image[frame_idx],
label[frame_idx],
crop_height=crop_size[0] if crop_size is not None else None,
crop_width=crop_size[1] if crop_size is not None else None,
min_resize_value=min_resize_value,
max_resize_value=max_resize_value,
resize_factor=resize_factor,
min_scale_factor=min_scale_factor,
max_scale_factor=max_scale_factor,
scale_factor_step_size=scale_factor_step_size,
ignore_label=dataset.ignore_label,
is_training=is_training,
model_variant=model_variant))
original_images.append(original_image_t)
images.append(image_t)
labels.append(label_t)
if sample_adjacent_and_consistent_query_frames:
imgs_for_preproc = [image[frame_idx] for frame_idx in
range(1, num_frames_per_video)]
labels_for_preproc = [label[frame_idx] for frame_idx in
range(1, num_frames_per_video)]
original_image_rest, image_rest, label_rest = (
input_preprocess.preprocess_images_and_labels_consistently(
imgs_for_preproc,
labels_for_preproc,
crop_height=crop_size[0] if crop_size is not None else None,
crop_width=crop_size[1] if crop_size is not None else None,
min_resize_value=min_resize_value,
max_resize_value=max_resize_value,
resize_factor=resize_factor,
min_scale_factor=min_scale_factor,
max_scale_factor=max_scale_factor,
scale_factor_step_size=scale_factor_step_size,
ignore_label=dataset.ignore_label,
is_training=is_training,
model_variant=model_variant))
original_images.extend(original_image_rest)
images.extend(image_rest)
labels.extend(label_rest)
assert len(original_images) == num_frames_per_video
assert len(images) == num_frames_per_video
assert len(labels) == num_frames_per_video
if remap_labels_to_reference_frame:
# Remap labels to indices into the labels of the (downscaled) reference
# frame, or 0, i.e. background, for labels which are not present
# in the reference.
reference_labels = labels[0][tf.newaxis]
h, w = train_utils.resolve_shape(reference_labels)[1:3]
embedding_height = model.scale_dimension(
h, 1.0 / decoder_output_stride)
embedding_width = model.scale_dimension(
w, 1.0 / decoder_output_stride)
reference_labels_embedding_size = tf.squeeze(
tf.image.resize_nearest_neighbor(
reference_labels, tf.stack([embedding_height, embedding_width]),
align_corners=True),
axis=0)
# Get sorted unique labels in the reference frame.
labels_in_ref_frame, _ = tf.unique(
tf.reshape(reference_labels_embedding_size, [-1]))
labels_in_ref_frame = tf.contrib.framework.sort(labels_in_ref_frame)
for idx in range(1, len(labels)):
ref_label_mask = tf.equal(
labels[idx],
labels_in_ref_frame[tf.newaxis, tf.newaxis, :])
remapped = tf.argmax(tf.cast(ref_label_mask, tf.uint8), axis=-1,
output_type=tf.int32)
# Set to 0 if label is not present
is_in_ref = tf.reduce_any(ref_label_mask, axis=-1)
remapped *= tf.cast(is_in_ref, tf.int32)
labels[idx] = remapped[..., tf.newaxis]
if sample_adjacent_and_consistent_query_frames:
if first_frame_finetuning and generate_prev_frame_mask_by_mask_damaging:
preceding_frame_label = mask_damaging.damage_masks(labels[1])
elif add_prev_frame_label:
# Discard the image of the additional frame and take the label as
# initialization for softmax feedback.
original_images = [original_images[0]] + original_images[2:]
preceding_frame_label = labels[1]
images = [images[0]] + images[2:]
labels = [labels[0]] + labels[2:]
num_frames_per_video -= 1
original_image = tf.stack(original_images, axis=0)
image = tf.stack(images, axis=0)
label = tf.stack(labels, axis=0)
else:
if label is not None:
# Need to set label shape due to batching.
label.set_shape([num_frames_per_video,
None if crop_size is None else crop_size[0],
None if crop_size is None else crop_size[1],
1])
original_image = tf.to_float(tf.zeros_like(label))
if crop_size is None:
height = tf.shape(image)[1]
width = tf.shape(image)[2]
else:
height = crop_size[0]
width = crop_size[1]
sample = {'image': image,
'image_name': image_name,
'height': height,
'width': width,
'video_id': video_id}
if label is not None:
sample['label'] = label
if object_label is not None:
sample['object_label'] = object_label
if preceding_frame_label is not None:
sample['preceding_frame_label'] = preceding_frame_label
if not is_training:
# Original image is only used during visualization.
sample['original_image'] = original_image
if is_training:
if first_frame_finetuning:
keep_input = tf.constant(True)
else:
keep_input = tf.logical_and(sampling_is_valid, tf.logical_and(
_has_enough_pixels_of_each_object_in_first_frame(
label, decoder_output_stride),
_has_foreground_and_background_in_first_frame_2(
label, decoder_output_stride)))
batched = tf.train.maybe_batch(sample,
keep_input=keep_input,
batch_size=batch_size,
num_threads=num_threads,
capacity=batch_capacity_factor * batch_size,
dynamic_pad=True)
else:
batched = tf.train.batch(sample,
batch_size=batch_size,
num_threads=num_threads,
capacity=batch_capacity_factor * batch_size,
dynamic_pad=True)
  # Flatten from [batch, num_frames_per_video, ...] to
  # [batch * num_frames_per_video, ...].
cropped_height = train_utils.resolve_shape(batched['image'])[2]
cropped_width = train_utils.resolve_shape(batched['image'])[3]
if num_frames_per_video is None:
first_dim = -1
else:
first_dim = batch_size * num_frames_per_video
batched['image'] = tf.reshape(batched['image'],
[first_dim, cropped_height, cropped_width, 3])
if label is not None:
batched['label'] = tf.reshape(batched['label'],
[first_dim, cropped_height, cropped_width, 1])
return batched
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Segmentation results evaluation and visualization for videos using attention.
"""
import math
import os
import time
import numpy as np
import tensorflow as tf
from feelvos import common
from feelvos import model
from feelvos.datasets import video_dataset
from feelvos.utils import embedding_utils
from feelvos.utils import eval_utils
from feelvos.utils import video_input_generator
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer('eval_interval_secs', 60 * 5,
'How often (in seconds) to run evaluation.')
flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')
flags.DEFINE_integer('vis_batch_size', 1,
'The number of images in each batch during evaluation.')
flags.DEFINE_string('vis_logdir', None, 'Where to write the event logs.')
flags.DEFINE_string('checkpoint_dir', None, 'Directory of model checkpoints.')
flags.DEFINE_integer('output_stride', 8,
'The ratio of input to output spatial resolution.')
flags.DEFINE_string('dataset', 'davis_2016',
'Name of the segmentation dataset.')
flags.DEFINE_string('vis_split', 'val',
                    'Which split of the dataset is used for visualizing results.')
flags.DEFINE_string(
'dataset_dir',
'/cns/is-d/home/lcchen/data/pascal_voc_seg/example_sstables',
'Where the dataset resides.')
flags.DEFINE_integer('num_vis_examples', -1,
'Number of examples for visualization. If -1, use all '
'samples in the vis data.')
flags.DEFINE_multi_integer('atrous_rates', None,
'Atrous rates for atrous spatial pyramid pooling.')
flags.DEFINE_bool('save_segmentations', False, 'Whether to save the '
'segmentation masks as '
'png images. Might be slow '
'on cns.')
flags.DEFINE_bool('save_embeddings', False, 'Whether to save the embeddings as '
'pickle. Might be slow on cns.')
flags.DEFINE_bool('eval_once_and_quit', False,
'Whether to just run the eval a single time and quit '
'afterwards. Otherwise, the eval is run in a loop with '
'new checkpoints.')
flags.DEFINE_boolean('first_frame_finetuning', False,
'Whether to only sample the first frame for fine-tuning.')
# The folder where segmentations are saved.
_SEGMENTATION_SAVE_FOLDER = 'segmentation'
_EMBEDDINGS_SAVE_FOLDER = 'embeddings'
def _process_seq_data(segmentation_dir, embeddings_dir, seq_name,
predicted_labels, gt_labels, embeddings):
"""Calculates the sequence IoU and optionally save the segmentation masks.
Args:
segmentation_dir: Directory in which the segmentation results are stored.
embeddings_dir: Directory in which the embeddings are stored.
seq_name: String, the name of the sequence.
predicted_labels: Int64 np.array of shape [n_frames, height, width].
gt_labels: Ground truth labels, Int64 np.array of shape
[n_frames, height, width].
embeddings: Float32 np.array of embeddings of shape
[n_frames, decoder_height, decoder_width, embedding_dim], or None.
Returns:
The IoU for the sequence (float).
"""
sequence_dir = os.path.join(segmentation_dir, seq_name)
tf.gfile.MakeDirs(sequence_dir)
embeddings_seq_dir = os.path.join(embeddings_dir, seq_name)
tf.gfile.MakeDirs(embeddings_seq_dir)
label_set = np.unique(gt_labels[0])
ious = []
assert len(predicted_labels) == len(gt_labels)
if embeddings is not None:
assert len(predicted_labels) == len(embeddings)
for t, (predicted_label, gt_label) in enumerate(
zip(predicted_labels, gt_labels)):
if FLAGS.save_segmentations:
seg_filename = os.path.join(segmentation_dir, seq_name, '%05d.png' % t)
eval_utils.save_segmentation_with_colormap(seg_filename, predicted_label)
if FLAGS.save_embeddings:
embedding_filename = os.path.join(embeddings_dir, seq_name,
'%05d.npy' % t)
assert embeddings is not None
eval_utils.save_embeddings(embedding_filename, embeddings[t])
object_ious_t = eval_utils.calculate_multi_object_ious(
predicted_label, gt_label, label_set)
ious.append(object_ious_t)
# First and last frame are excluded in DAVIS eval.
seq_ious = np.mean(ious[1:-1], axis=0)
tf.logging.info('seq ious: %s %s', seq_name, seq_ious)
return seq_ious
def create_predictions(samples, reference_labels, first_frame_img,
model_options):
"""Predicts segmentation labels for each frame of the video.
Slower version than create_predictions_fast, but does support more options.
Args:
samples: Dictionary of input samples.
reference_labels: Int tensor of shape [1, height, width, 1].
first_frame_img: Float32 tensor of shape [height, width, 3].
model_options: An InternalModelOptions instance to configure models.
Returns:
predicted_labels: Int tensor of shape [time, height, width] of
predicted labels for each frame.
all_embeddings: Float32 tensor of shape
[time, height, width, embedding_dim], or None.
"""
def predict(args, imgs):
"""Predicts segmentation labels and softmax probabilities for each image.
Args:
args: A tuple of (predictions, softmax_probabilities), where predictions
is an int tensor of shape [1, h, w] and softmax_probabilities is a
float32 tensor of shape [1, h_decoder, w_decoder, n_objects].
      imgs: Either a one-tuple containing the image to predict labels for, of
        shape [h, w, 3], or a pair of the previous frame and current frame
        images.
Returns:
predictions: The predicted labels as int tensor of shape [1, h, w].
softmax_probabilities: The softmax probabilities of shape
[1, h_decoder, w_decoder, n_objects].
"""
if FLAGS.save_embeddings:
last_frame_predictions, last_softmax_probabilities, _ = args
else:
last_frame_predictions, last_softmax_probabilities = args
if FLAGS.also_attend_to_previous_frame or FLAGS.use_softmax_feedback:
ref_labels_to_use = tf.concat(
[reference_labels, last_frame_predictions[..., tf.newaxis]],
axis=0)
else:
ref_labels_to_use = reference_labels
predictions, softmax_probabilities = model.predict_labels(
tf.stack((first_frame_img,) + imgs),
model_options=model_options,
image_pyramid=FLAGS.image_pyramid,
embedding_dimension=FLAGS.embedding_dimension,
reference_labels=ref_labels_to_use,
k_nearest_neighbors=FLAGS.k_nearest_neighbors,
use_softmax_feedback=FLAGS.use_softmax_feedback,
initial_softmax_feedback=last_softmax_probabilities,
embedding_seg_feature_dimension=
FLAGS.embedding_seg_feature_dimension,
embedding_seg_n_layers=FLAGS.embedding_seg_n_layers,
embedding_seg_kernel_size=FLAGS.embedding_seg_kernel_size,
embedding_seg_atrous_rates=FLAGS.embedding_seg_atrous_rates,
also_return_softmax_probabilities=True,
num_frames_per_video=
(3 if FLAGS.also_attend_to_previous_frame or
FLAGS.use_softmax_feedback else 2),
normalize_nearest_neighbor_distances=
FLAGS.normalize_nearest_neighbor_distances,
also_attend_to_previous_frame=FLAGS.also_attend_to_previous_frame,
use_local_previous_frame_attention=
FLAGS.use_local_previous_frame_attention,
previous_frame_attention_window_size=
FLAGS.previous_frame_attention_window_size,
use_first_frame_matching=FLAGS.use_first_frame_matching
)
predictions = tf.cast(predictions[common.OUTPUT_TYPE], tf.int32)
if FLAGS.save_embeddings:
names = [n.name for n in tf.get_default_graph().as_graph_def().node]
embedding_names = [x for x in names if 'embeddings' in x]
# This will crash when multi-scale inference is used.
assert len(embedding_names) == 1, len(embedding_names)
embedding_name = embedding_names[0] + ':0'
embeddings = tf.get_default_graph().get_tensor_by_name(embedding_name)
return predictions, softmax_probabilities, embeddings
else:
return predictions, softmax_probabilities
init_labels = tf.squeeze(reference_labels, axis=-1)
init_softmax = embedding_utils.create_initial_softmax_from_labels(
reference_labels, reference_labels, common.parse_decoder_output_stride(),
reduce_labels=False)
if FLAGS.save_embeddings:
decoder_height = tf.shape(init_softmax)[1]
decoder_width = tf.shape(init_softmax)[2]
n_frames = (3 if FLAGS.also_attend_to_previous_frame
or FLAGS.use_softmax_feedback else 2)
embeddings_init = tf.zeros((n_frames, decoder_height, decoder_width,
FLAGS.embedding_dimension))
init = (init_labels, init_softmax, embeddings_init)
else:
init = (init_labels, init_softmax)
# Do not eval the first frame again but concat the first frame ground
# truth instead.
if FLAGS.also_attend_to_previous_frame or FLAGS.use_softmax_feedback:
elems = (samples[common.IMAGE][:-1], samples[common.IMAGE][1:])
else:
elems = (samples[common.IMAGE][1:],)
res = tf.scan(predict, elems,
initializer=init,
parallel_iterations=1,
swap_memory=True)
if FLAGS.save_embeddings:
predicted_labels, _, all_embeddings = res
first_frame_embeddings = all_embeddings[0, 0, tf.newaxis]
other_frame_embeddings = all_embeddings[:, -1]
all_embeddings = tf.concat(
[first_frame_embeddings, other_frame_embeddings], axis=0)
else:
predicted_labels, _ = res
all_embeddings = None
predicted_labels = tf.concat([reference_labels[..., 0],
tf.squeeze(predicted_labels, axis=1)],
axis=0)
return predicted_labels, all_embeddings
def create_predictions_fast(samples, reference_labels, first_frame_img,
model_options):
"""Predicts segmentation labels for each frame of the video.
Faster version than create_predictions, but does not support all options.
Args:
samples: Dictionary of input samples.
reference_labels: Int tensor of shape [1, height, width, 1].
first_frame_img: Float32 tensor of shape [height, width, 3].
model_options: An InternalModelOptions instance to configure models.
Returns:
predicted_labels: Int tensor of shape [time, height, width] of
predicted labels for each frame.
all_embeddings: Float32 tensor of shape
[time, height, width, embedding_dim], or None.
Raises:
ValueError: If FLAGS.save_embeddings is True, FLAGS.use_softmax_feedback is
False, or FLAGS.also_attend_to_previous_frame is False.
"""
if FLAGS.save_embeddings:
raise ValueError('save_embeddings does not work with '
'create_predictions_fast. Use the slower '
'create_predictions instead.')
if not FLAGS.use_softmax_feedback:
raise ValueError('use_softmax_feedback must be True for '
'create_predictions_fast. Use the slower '
'create_predictions instead.')
if not FLAGS.also_attend_to_previous_frame:
raise ValueError('also_attend_to_previous_frame must be True for '
'create_predictions_fast. Use the slower '
'create_predictions instead.')
# Extract embeddings for first frame and prepare initial predictions.
first_frame_embeddings = embedding_utils.get_embeddings(
first_frame_img[tf.newaxis], model_options, FLAGS.embedding_dimension)
init_labels = tf.squeeze(reference_labels, axis=-1)
init_softmax = embedding_utils.create_initial_softmax_from_labels(
reference_labels, reference_labels, common.parse_decoder_output_stride(),
reduce_labels=False)
init = (init_labels, init_softmax, first_frame_embeddings)
def predict(args, img):
"""Predicts segmentation labels and softmax probabilities for each image.
Args:
args: tuple of
(predictions, softmax_probabilities, last_frame_embeddings), where
predictions is an int tensor of shape [1, h, w],
softmax_probabilities is a float32 tensor of shape
[1, h_decoder, w_decoder, n_objects],
and last_frame_embeddings is a float32 tensor of shape
[h_decoder, w_decoder, embedding_dimension].
img: Image to predict labels for of shape [h, w, 3].
Returns:
predictions: The predicted labels as int tensor of shape [1, h, w].
softmax_probabilities: The softmax probabilities of shape
[1, h_decoder, w_decoder, n_objects].
"""
(last_frame_predictions, last_softmax_probabilities,
prev_frame_embeddings) = args
ref_labels_to_use = tf.concat(
[reference_labels, last_frame_predictions[..., tf.newaxis]],
axis=0)
predictions, softmax_probabilities, embeddings = model.predict_labels(
img[tf.newaxis],
model_options=model_options,
image_pyramid=FLAGS.image_pyramid,
embedding_dimension=FLAGS.embedding_dimension,
reference_labels=ref_labels_to_use,
k_nearest_neighbors=FLAGS.k_nearest_neighbors,
use_softmax_feedback=FLAGS.use_softmax_feedback,
initial_softmax_feedback=last_softmax_probabilities,
embedding_seg_feature_dimension=
FLAGS.embedding_seg_feature_dimension,
embedding_seg_n_layers=FLAGS.embedding_seg_n_layers,
embedding_seg_kernel_size=FLAGS.embedding_seg_kernel_size,
embedding_seg_atrous_rates=FLAGS.embedding_seg_atrous_rates,
also_return_softmax_probabilities=True,
num_frames_per_video=1,
normalize_nearest_neighbor_distances=
FLAGS.normalize_nearest_neighbor_distances,
also_attend_to_previous_frame=FLAGS.also_attend_to_previous_frame,
use_local_previous_frame_attention=
FLAGS.use_local_previous_frame_attention,
previous_frame_attention_window_size=
FLAGS.previous_frame_attention_window_size,
use_first_frame_matching=FLAGS.use_first_frame_matching,
also_return_embeddings=True,
ref_embeddings=(first_frame_embeddings, prev_frame_embeddings)
)
predictions = tf.cast(predictions[common.OUTPUT_TYPE], tf.int32)
return predictions, softmax_probabilities, embeddings
# Do not eval the first frame again but concat the first frame ground
# truth instead.
# If you have a lot of GPU memory, you can try to set swap_memory=False,
# and/or parallel_iterations=2.
elems = samples[common.IMAGE][1:]
res = tf.scan(predict, elems,
initializer=init,
parallel_iterations=1,
swap_memory=True)
predicted_labels, _, _ = res
predicted_labels = tf.concat([reference_labels[..., 0],
tf.squeeze(predicted_labels, axis=1)],
axis=0)
return predicted_labels
def main(unused_argv):
if FLAGS.vis_batch_size != 1:
raise ValueError('Only batch size 1 is supported for now')
data_type = 'tf_sequence_example'
# Get dataset-dependent information.
dataset = video_dataset.get_dataset(
FLAGS.dataset,
FLAGS.vis_split,
dataset_dir=FLAGS.dataset_dir,
data_type=data_type)
# Prepare for visualization.
tf.gfile.MakeDirs(FLAGS.vis_logdir)
segmentation_dir = os.path.join(FLAGS.vis_logdir, _SEGMENTATION_SAVE_FOLDER)
tf.gfile.MakeDirs(segmentation_dir)
embeddings_dir = os.path.join(FLAGS.vis_logdir, _EMBEDDINGS_SAVE_FOLDER)
tf.gfile.MakeDirs(embeddings_dir)
num_vis_examples = (dataset.num_videos if (FLAGS.num_vis_examples < 0)
else FLAGS.num_vis_examples)
if FLAGS.first_frame_finetuning:
num_vis_examples = 1
tf.logging.info('Visualizing on %s set', FLAGS.vis_split)
g = tf.Graph()
with g.as_default():
# Without setting device to CPU we run out of memory.
with tf.device('cpu:0'):
samples = video_input_generator.get(
dataset,
None,
None,
FLAGS.vis_batch_size,
min_resize_value=FLAGS.min_resize_value,
max_resize_value=FLAGS.max_resize_value,
resize_factor=FLAGS.resize_factor,
dataset_split=FLAGS.vis_split,
is_training=False,
model_variant=FLAGS.model_variant,
preprocess_image_and_label=False,
remap_labels_to_reference_frame=False)
samples[common.IMAGE] = tf.cast(samples[common.IMAGE], tf.float32)
samples[common.LABEL] = tf.cast(samples[common.LABEL], tf.int32)
first_frame_img = samples[common.IMAGE][0]
reference_labels = samples[common.LABEL][0, tf.newaxis]
gt_labels = tf.squeeze(samples[common.LABEL], axis=-1)
seq_name = samples[common.VIDEO_ID][0]
model_options = common.VideoModelOptions(
outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_classes},
crop_size=None,
atrous_rates=FLAGS.atrous_rates,
output_stride=FLAGS.output_stride)
all_embeddings = None
predicted_labels = create_predictions_fast(
samples, reference_labels, first_frame_img, model_options)
# If you need more options like saving embeddings, replace the call to
# create_predictions_fast with create_predictions.
tf.train.get_or_create_global_step()
saver = tf.train.Saver(slim.get_variables_to_restore())
sv = tf.train.Supervisor(graph=g,
logdir=FLAGS.vis_logdir,
init_op=tf.global_variables_initializer(),
summary_op=None,
summary_writer=None,
global_step=None,
saver=saver)
num_batches = int(
math.ceil(num_vis_examples / float(FLAGS.vis_batch_size)))
last_checkpoint = None
    # Infinite loop to visualize the results when a new checkpoint is created.
while True:
last_checkpoint = slim.evaluation.wait_for_new_checkpoint(
FLAGS.checkpoint_dir, last_checkpoint)
start = time.time()
tf.logging.info(
'Starting visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
time.gmtime()))
tf.logging.info('Visualizing with model %s', last_checkpoint)
all_ious = []
with sv.managed_session(FLAGS.master,
start_standard_services=False) as sess:
sv.start_queue_runners(sess)
sv.saver.restore(sess, last_checkpoint)
for batch in range(num_batches):
ops = [predicted_labels, gt_labels, seq_name]
if FLAGS.save_embeddings:
ops.append(all_embeddings)
tf.logging.info('Visualizing batch %d / %d', batch + 1, num_batches)
res = sess.run(ops)
tf.logging.info('Forwarding done')
pred_labels_val, gt_labels_val, seq_name_val = res[:3]
if FLAGS.save_embeddings:
all_embeddings_val = res[3]
else:
all_embeddings_val = None
seq_ious = _process_seq_data(segmentation_dir, embeddings_dir,
seq_name_val, pred_labels_val,
gt_labels_val, all_embeddings_val)
all_ious.append(seq_ious)
all_ious = np.concatenate(all_ious, axis=0)
tf.logging.info('n_seqs %s, mIoU %f', all_ious.shape, all_ious.mean())
tf.logging.info(
'Finished visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
time.gmtime()))
result_dir = FLAGS.vis_logdir + '/results/'
tf.gfile.MakeDirs(result_dir)
with tf.gfile.GFile(result_dir + seq_name_val + '.txt', 'w') as f:
f.write(str(all_ious))
if FLAGS.first_frame_finetuning or FLAGS.eval_once_and_quit:
break
time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
if time_to_next_eval > 0:
time.sleep(time_to_next_eval)
if __name__ == '__main__':
flags.mark_flag_as_required('checkpoint_dir')
flags.mark_flag_as_required('vis_logdir')
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run()
*.pkl binary
*.tfrecord binary
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
.static_storage/
.media/
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
![No Maintenance Intended](https://img.shields.io/badge/No%20Maintenance%20Intended-%E2%9C%95-red.svg)
![TensorFlow Requirement: 1.x](https://img.shields.io/badge/TensorFlow%20Requirement-1.x-brightgreen)
![TensorFlow 2 Not Supported](https://img.shields.io/badge/TensorFlow%202%20Not%20Supported-%E2%9C%95-red.svg)
# Filtering Variational Objectives
This folder contains a TensorFlow implementation of the algorithms from
Chris J. Maddison\*, Dieterich Lawson\*, George Tucker\*, Nicolas Heess, Mohammad Norouzi, Andriy Mnih, Arnaud Doucet, and Yee Whye Teh. "Filtering Variational Objectives." NIPS 2017.
[https://arxiv.org/abs/1705.09279](https://arxiv.org/abs/1705.09279)
This code implements 3 different bounds for training sequential latent variable models: the evidence lower bound (ELBO), the importance weighted auto-encoder bound (IWAE), and our bound, the filtering variational objective (FIVO).
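For orientation, here is a minimal sketch (illustrative only, not the project's implementation; see `fivo/bounds.py` for that) of how an IWAE-style bound is computed from per-sample log weights with a numerically stable log-mean-exp. With a single sample it reduces to the ELBO:
```
import tensorflow as tf

def log_mean_exp(log_weights, num_samples):
  # log( (1/K) * sum_k exp(log_weights_k) ), computed stably over axis 0.
  return (tf.reduce_logsumexp(log_weights, axis=0) -
          tf.log(tf.cast(num_samples, log_weights.dtype)))

# log_weights: [num_samples, batch_size] accumulated log p(x, z) - log q(z | x).
log_weights = tf.random_normal([4, 2])  # placeholder values for illustration
iwae_bound = log_mean_exp(log_weights, num_samples=4)
```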
Additionally it contains several sequential latent variable model implementations:
* Variational recurrent neural network (VRNN)
* Stochastic recurrent neural network (SRNN)
* Gaussian hidden Markov model with linear conditionals (GHMM)
The VRNN and SRNN can be trained for sequence modeling of pianoroll and speech data. The GHMM is trainable on a synthetic dataset, useful as a simple example of an analytically tractable model.
#### Directory Structure
The important parts of the code are organized as follows.
```
run_fivo.py # main script, contains flag definitions
fivo
├─smc.py # a sequential Monte Carlo implementation
├─bounds.py # code for computing each bound, uses smc.py
├─runners.py # code for VRNN and SRNN training and evaluation
├─ghmm_runners.py # code for GHMM training and evaluation
├─data
| ├─datasets.py # readers for pianoroll and speech datasets
| ├─calculate_pianoroll_mean.py # preprocesses the pianoroll datasets
| └─create_timit_dataset.py # preprocesses the TIMIT dataset
└─models
├─base.py # base classes used in other models
├─vrnn.py # VRNN implementation
├─srnn.py # SRNN implementation
└─ghmm.py # Gaussian hidden Markov model (GHMM) implementation
bin
├─run_train.sh # an example script that runs training
├─run_eval.sh # an example script that runs evaluation
├─run_sample.sh # an example script that runs sampling
├─run_tests.sh # a script that runs all tests
└─download_pianorolls.sh # a script that downloads pianoroll files
```
### Pianorolls
Requirements before we start:
* TensorFlow (see [tensorflow.org](http://tensorflow.org) for how to install)
* [scipy](https://www.scipy.org/)
* [sonnet](https://github.com/deepmind/sonnet)
#### Download the Data
The pianoroll datasets are encoded as pickled sparse arrays and are available at [http://www-etud.iro.umontreal.ca/~boulanni/icml2012](http://www-etud.iro.umontreal.ca/~boulanni/icml2012). You can use the script `bin/download_pianorolls.sh` to download the files into a directory of your choosing.
```
export PIANOROLL_DIR=~/pianorolls
mkdir $PIANOROLL_DIR
sh bin/download_pianorolls.sh $PIANOROLL_DIR
```
#### Preprocess the Data
The script `calculate_pianoroll_mean.py` loads a pianoroll pickle file, calculates the mean, updates the pickle file to include the mean under the key `train_mean`, and writes the file back to disk in-place. You should do this for all pianoroll datasets you wish to train on.
```
python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/piano-midi.de.pkl
python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/nottingham.pkl
python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/musedata.pkl
python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/jsb.pkl
```
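If you want to sanity-check the preprocessing, you can inspect a pickle directly. The snippet below is only an illustrative sketch; the path is an example, and apart from the `train_mean` key written by the step above, the exact key layout of the pickle is an assumption:
```
import pickle

# Illustrative only: the path is an example; 'train_mean' is the key added
# in-place by calculate_pianoroll_mean.py as described above.
with open('/path/to/pianorolls/jsb.pkl', 'rb') as f:
  rolls = pickle.load(f)
print(sorted(rolls.keys()))  # should now include 'train_mean'
```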
#### Training
Now we can train a model. Here is the command for a standard training run, taken from `bin/run_train.sh`:
```
python run_fivo.py \
--mode=train \
--logdir=/tmp/fivo \
--model=vrnn \
--bound=fivo \
--summarize_every=100 \
--batch_size=4 \
--num_samples=4 \
--learning_rate=0.0001 \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll"
```
You should see output that looks something like this (with extra logging cruft):
```
Saving checkpoints for 0 into /tmp/fivo/model.ckpt.
Step 1, fivo bound per timestep: -11.322491
global_step/sec: 7.49971
Step 101, fivo bound per timestep: -11.399275
global_step/sec: 8.04498
Step 201, fivo bound per timestep: -11.174991
global_step/sec: 8.03989
Step 301, fivo bound per timestep: -11.073008
```
#### Evaluation
You can also evaluate saved checkpoints. The `eval` mode loads a model checkpoint, tests its performance on all items in a dataset, and reports the log-likelihood averaged over the dataset. For example here is a command, taken from `bin/run_eval.sh`, that will evaluate a JSB model on the test set:
```
python run_fivo.py \
--mode=eval \
--split=test \
--alsologtostderr \
--logdir=/tmp/fivo \
--model=vrnn \
--batch_size=4 \
--num_samples=4 \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll"
```
You should see output like this:
```
Restoring parameters from /tmp/fivo/model.ckpt-0
Model restored from step 0, evaluating.
test elbo ll/t: -12.198834, iwae ll/t: -11.981187 fivo ll/t: -11.579776
test elbo ll/seq: -748.564789, iwae ll/seq: -735.209206 fivo ll/seq: -710.577141
```
The evaluation script prints log-likelihood in both nats per timestep (ll/t) and nats per sequence (ll/seq) for all three bounds.
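As a rough consistency check on the example numbers above, the per-sequence and per-timestep figures differ only by the average sequence length of the evaluation split (here roughly 61 timesteps):
```
# Illustrative arithmetic using the example output above.
print(748.564789 / 12.198834)   # ~61.4 (elbo)
print(735.209206 / 11.981187)   # ~61.4 (iwae)
print(710.577141 / 11.579776)   # ~61.4 (fivo)
```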
#### Sampling
You can also sample from trained models. The `sample` mode loads a model checkpoint, conditions the model on a prefix of a randomly chosen datapoint, samples a sequence of outputs from the conditioned model, and writes out the samples and prefix to a `.npz` file in `logdir`. For example here is a command that samples from a model trained on JSB, taken from `bin/run_sample.sh`:
```
python run_fivo.py \
--mode=sample \
--alsologtostderr \
--logdir="/tmp/fivo" \
--model=vrnn \
--bound=fivo \
--batch_size=4 \
--num_samples=4 \
--split=test \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll" \
--prefix_length=25 \
--sample_length=50
```
Here `num_samples` denotes the number of samples used when conditioning the model as well as the number of trajectories to sample for each prefix.
You should see very little output.
```
Restoring parameters from /tmp/fivo/model.ckpt-0
Running local_init_op.
Done running local_init_op.
```
Loading the samples with `np.load` confirms that we conditioned the model on 4
prefixes of length 25 and sampled 4 sequences of length 50 for each prefix.
```
>>> import numpy as np
>>> x = np.load("/tmp/fivo/samples.npz")
>>> x[()]['prefixes'].shape
(25, 4, 88)
>>> x[()]['samples'].shape
(50, 4, 4, 88)
```
### Training on TIMIT
The TIMIT speech dataset is available at the [Linguistic Data Consortium website](https://catalog.ldc.upenn.edu/LDC93S1), but is unfortunately not free. These instructions will proceed assuming you have downloaded the TIMIT archive and extracted it into the directory `$RAW_TIMIT_DIR`.
#### Preprocess TIMIT
We preprocess TIMIT (as described in our paper) and write it out to a series of TFRecord files. To prepare the TIMIT dataset, use the script `create_timit_dataset.py`:
```
export TIMIT_DIR=~/timit_dataset
mkdir $TIMIT_DIR
python data/create_timit_dataset.py \
--raw_timit_dir=$RAW_TIMIT_DIR \
--out_dir=$TIMIT_DIR
```
You should see this exact output:
```
4389 train / 231 valid / 1680 test
train mean: 0.006060 train std: 548.136169
```
#### Training on TIMIT
This is very similar to training on pianoroll datasets, with just a few flags switched.
```
python run_fivo.py \
--mode=train \
--logdir=/tmp/fivo \
--model=vrnn \
--bound=fivo \
--summarize_every=100 \
--batch_size=4 \
--num_samples=4 \
--learning_rate=0.0001 \
--dataset_path="$TIMIT_DIR/train" \
--dataset_type="speech"
```
Evaluation and sampling are similar.
### Tests
This codebase comes with a number of tests to verify correctness, runnable via `bin/run_tests.sh`. The tests are also useful to look at for examples of how to use the code.
### Contact
This codebase is maintained by Dieterich Lawson. For questions and issues please open an issue on the tensorflow/models issues tracker and assign it to @dieterichlawson.
#!/bin/bash
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# A script to download the pianoroll datasets.
# Accepts one argument, the directory to put the files in
if [ -z "$1" ]
then
echo "Error, must provide a directory to download the files to."
  exit 1
fi
echo "Downloading datasets into $1"
curl -s "http://www-etud.iro.umontreal.ca/~boulanni/Piano-midi.de.pickle" > $1/piano-midi.de.pkl
curl -s "http://www-etud.iro.umontreal.ca/~boulanni/Nottingham.pickle" > $1/nottingham.pkl
curl -s "http://www-etud.iro.umontreal.ca/~boulanni/MuseData.pickle" > $1/musedata.pkl
curl -s "http://www-etud.iro.umontreal.ca/~boulanni/JSB%20Chorales.pickle" > $1/jsb.pkl
#!/bin/bash
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# An example of sampling from the model.
PIANOROLL_DIR=$HOME/pianorolls
python run_fivo.py \
--mode=sample \
--alsologtostderr \
--logdir="/tmp/fivo" \
--model=vrnn \
--bound=fivo \
--batch_size=4 \
--num_samples=4 \
--split=test \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll" \
--prefix_length=25 \
--sample_length=50
#!/bin/bash
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
python -m fivo.smc_test && \
python -m fivo.bounds_test && \
python -m fivo.nested_utils_test && \
python -m fivo.data.datasets_test && \
python -m fivo.models.ghmm_test && \
python -m fivo.models.vrnn_test && \
python -m fivo.models.srnn_test && \
python -m fivo.ghmm_runners_test && \
python -m fivo.runners_test
#!/bin/bash
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# An example of running training.
PIANOROLL_DIR=$HOME/pianorolls
python run_fivo.py \
--mode=train \
--logdir=/tmp/fivo \
--model=vrnn \
--bound=fivo \
--summarize_every=100 \
--batch_size=4 \
--num_samples=4 \
--learning_rate=0.0001 \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll"
An experimental codebase for running simple examples.
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import tensorflow as tf
import summary_utils as summ
Loss = namedtuple("Loss", "name loss vars")
Loss.__new__.__defaults__ = (tf.GraphKeys.TRAINABLE_VARIABLES,)
def iwae(model, observation, num_timesteps, num_samples=1,
summarize=False):
"""Compute the IWAE evidence lower bound.
Args:
model: A callable that computes one timestep of the model.
observation: A shape [batch_size*num_samples, state_size] Tensor
containing z_n, the observation for each sequence in the batch.
num_timesteps: The number of timesteps in each sequence, an integer.
    num_samples: The number of samples to use to compute the IWAE bound.
    summarize: Whether to add a summary of the weight entropy at each timestep.
  Returns:
    log_p_hat: The IWAE estimator of the lower bound on the log marginal.
    losses: A list of Loss tuples that you can perform gradient descent on to
      optimize the bound.
    maintain_ema_op: A no-op included for compatibility with FIVO.
    states: The sequence of states sampled.
    log_weights: The accumulated log weights at each timestep.
"""
# Initialization
num_instances = tf.shape(observation)[0]
batch_size = tf.cast(num_instances / num_samples, tf.int32)
states = [model.zero_state(num_instances)]
log_weights = []
log_weight_acc = tf.zeros([num_samples, batch_size], dtype=observation.dtype)
for t in xrange(num_timesteps):
# run the model for one timestep
(zt, log_q_zt, log_p_zt, log_p_x_given_z, _) = model(
states[-1], observation, t)
# update accumulators
states.append(zt)
log_weight = log_p_zt + log_p_x_given_z - log_q_zt
log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
if summarize:
weight_dist = tf.contrib.distributions.Categorical(
logits=tf.transpose(log_weight_acc, perm=[1, 0]),
allow_nan_stats=False)
weight_entropy = weight_dist.entropy()
weight_entropy = tf.reduce_mean(weight_entropy)
tf.summary.scalar("weight_entropy/%d" % t, weight_entropy)
log_weights.append(log_weight_acc)
# Compute the lower bound on the log evidence.
log_p_hat = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
tf.log(tf.cast(num_samples, observation.dtype))) / num_timesteps
loss = -tf.reduce_mean(log_p_hat)
losses = [Loss("log_p_hat", loss)]
# we clip off the initial state before returning.
# there are no emas for iwae, so we return a noop for that
return log_p_hat, losses, tf.no_op(), states[1:], log_weights
def multinomial_resampling(log_weights, states, n, b):
"""Resample states with multinomial resampling.
Args:
log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
Categorical distribution.
    states: A list of (b*n x d) Tensors that will be resampled from the groups
      of every n-th row.
    n: The number of particles (samples) per instance.
    b: The batch size.
  Returns:
    resampled_states: A list of (b*n x d) Tensors resampled via multinomial
      resampling.
log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
resampling_parameters: The Tensor of parameters of the resampling distribution.
ancestors: An (n x b) Tensor of integral indices representing the ancestry decisions.
resampling_dist: The distribution object for resampling.
"""
log_weights = tf.convert_to_tensor(log_weights)
states = [tf.convert_to_tensor(state) for state in states]
resampling_parameters = tf.transpose(log_weights, perm=[1,0])
resampling_dist = tf.contrib.distributions.Categorical(logits=resampling_parameters)
ancestors = tf.stop_gradient(
resampling_dist.sample(sample_shape=n))
log_probs = resampling_dist.log_prob(ancestors)
offset = tf.expand_dims(tf.range(b), 0)
ancestor_inds = tf.reshape(ancestors * b + offset, [-1])
resampled_states = []
for state in states:
resampled_states.append(tf.gather(state, ancestor_inds))
return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist
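# Shape sketch (illustrative, not part of the code above): with n=3 particles
# and b=2 sequences, log_weights has shape [3, 2] and each state Tensor has
# shape [6, d]; multinomial_resampling then returns state Tensors of the same
# [6, d] shape, with rows gathered according to the sampled ancestor indices.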
def stratified_resampling(log_weights, states, n, b):
"""Resample states with straitified resampling.
Args:
log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
Categorical distribution.
states: A list of (b*n x d) Tensors that will be resample in from the groups
of every n-th row.
Returns:
resampled_states: A list of (b*n x d) Tensors resampled via stratified sampling.
log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
resampling_parameters: The Tensor of parameters of the resampling distribution.
ancestors: An (n x b) Tensor of integral indices representing the ancestry decisions.
resampling_dist: The distribution object for resampling.
"""
log_weights = tf.convert_to_tensor(log_weights)
states = [tf.convert_to_tensor(state) for state in states]
log_weights = tf.transpose(log_weights, perm=[1,0])
probs = tf.nn.softmax(
tf.tile(tf.expand_dims(log_weights, axis=1),
[1, n, 1])
)
cdfs = tf.concat([tf.zeros((b,n,1), dtype=probs.dtype), tf.cumsum(probs, axis=2)], 2)
bins = tf.range(n, dtype=probs.dtype) / n
bins = tf.tile(tf.reshape(bins, [1,-1,1]), [b,1,n+1])
strat_cdfs = tf.minimum(tf.maximum((cdfs - bins) * n, 0.0), 1.0)
resampling_parameters = strat_cdfs[:,:,1:] - strat_cdfs[:,:,:-1]
resampling_dist = tf.contrib.distributions.Categorical(
probs = resampling_parameters,
allow_nan_stats=False)
ancestors = tf.stop_gradient(
resampling_dist.sample())
log_probs = resampling_dist.log_prob(ancestors)
ancestors = tf.transpose(ancestors, perm=[1,0])
log_probs = tf.transpose(log_probs, perm=[1,0])
offset = tf.expand_dims(tf.range(b), 0)
ancestor_inds = tf.reshape(ancestors * b + offset, [-1])
resampled_states = []
for state in states:
resampled_states.append(tf.gather(state, ancestor_inds))
return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist
def systematic_resampling(log_weights, states, n, b):
"""Resample states with systematic resampling.
Args:
log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
Categorical distribution.
    states: A list of (b*n x d) Tensors that will be resampled from the groups
      of every n-th row.
    n: The number of particles (samples) per instance.
    b: The batch size.
  Returns:
    resampled_states: A list of (b*n x d) Tensors resampled via systematic
      resampling.
log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
resampling_parameters: The Tensor of parameters of the resampling distribution.
ancestors: An (n x b) Tensor of integral indices representing the ancestry decisions.
resampling_dist: The distribution object for resampling.
"""
log_weights = tf.convert_to_tensor(log_weights)
states = [tf.convert_to_tensor(state) for state in states]
log_weights = tf.transpose(log_weights, perm=[1,0])
probs = tf.nn.softmax(
tf.tile(tf.expand_dims(log_weights, axis=1),
[1, n, 1])
)
cdfs = tf.concat([tf.zeros((b,n,1), dtype=probs.dtype), tf.cumsum(probs, axis=2)], 2)
bins = tf.range(n, dtype=probs.dtype) / n
bins = tf.tile(tf.reshape(bins, [1,-1,1]), [b,1,n+1])
strat_cdfs = tf.minimum(tf.maximum((cdfs - bins) * n, 0.0), 1.0)
resampling_parameters = strat_cdfs[:,:,1:] - strat_cdfs[:,:,:-1]
resampling_dist = tf.contrib.distributions.Categorical(
probs=resampling_parameters,
allow_nan_stats=True)
U = tf.random_uniform((b, 1, 1), dtype=probs.dtype)
ancestors = tf.stop_gradient(tf.reduce_sum(tf.to_float(U > strat_cdfs[:,:,1:]), axis=-1))
log_probs = resampling_dist.log_prob(ancestors)
ancestors = tf.transpose(ancestors, perm=[1,0])
log_probs = tf.transpose(log_probs, perm=[1,0])
offset = tf.expand_dims(tf.range(b, dtype=probs.dtype), 0)
# tf.gather requires integer indices, so cast the flattened (float) ancestor
# indices before gathering.
ancestor_inds = tf.to_int32(tf.reshape(ancestors * b + offset, [-1]))
resampled_states = []
for state in states:
resampled_states.append(tf.gather(state, ancestor_inds))
return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist
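# Note (added commentary; not part of the original module): systematic_resampling
# builds the same per-stratum distributions as stratified_resampling but replaces
# the n independent per-stratum draws with a single uniform U per batch element;
# the ancestor for each stratum is the number of per-stratum CDF boundaries that
# U exceeds, so all strata share one source of randomness while log_probs are
# still scored under the stratified distribution.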
def log_blend(inputs, weights):
"""Blends state in the log space.
Args:
inputs: A set of scalar states, one for each particle in each particle filter.
Should be [num_samples, batch_size].
weights: A set of weights used to blend the state. Each set of weights
  should be of dimension [num_samples] (one weight for each previous particle).
  There should be one set of weights for each new particle in each particle
  filter, so the overall shape should be [num_samples, batch_size, num_samples],
  where the first axis indexes the new particles and the last axis indexes the
  old particles.
Returns:
blended: The blended states, a tensor of shape [num_samples, batch_size].
"""
raw_max = tf.reduce_max(inputs, axis=0, keepdims=True)
my_max = tf.stop_gradient(
tf.where(tf.is_finite(raw_max), raw_max, tf.zeros_like(raw_max))
)
# Numerically stable log-space blend: subtract the per-column max before
# exponentiating, then add it back via my_max (non-finite maxes replaced by
# zeros, gradients stopped) after taking the log.
blended = tf.log(tf.einsum("ijk,kj->ij", weights, tf.exp(inputs - raw_max))) + my_max
return blended
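# Note (added commentary; not part of the original module): in index notation,
# log_blend computes
#   blended[i, j] = log(sum_k weights[i, j, k] * exp(inputs[k, j])),
# a weighted logsumexp over old particles k for each new particle i and batch
# element j, stabilized by subtracting the per-column max of inputs before
# exponentiating.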
def relaxed_resampling(log_weights, states, num_samples, batch_size,
log_r_x=None, blend_type="log", temperature=0.5,
straight_through=False):
"""Resample states with relaxed resampling.
Args:
log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
Categorical distribution.
states: A list of (b*n x d) Tensors that will be resampled, where the
  particles for each batch element are interleaved across rows (row i*b + j
  holds particle i of batch element j).
num_samples: The number of particles (samples) per batch element, n.
batch_size: The batch size, b.
log_r_x: An optional [num_samples*batch_size] Tensor of log r values to be
  blended along with the states.
blend_type: Either "log" or "linear"; how log_r_x is blended across ancestors.
temperature: The temperature of the RelaxedOneHotCategorical distribution.
straight_through: If True, use hard one-hot ancestors on the forward pass and
  the relaxed ancestors on the backward pass.
Returns:
resampled_states: A list of (b*n x d) Tensors resampled via relaxed resampling.
log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
log_r_x: The blended log r values, or None if log_r_x was not provided.
resampling_parameters: The Tensor of parameters of the resampling distribution.
ancestors: An (n x b x n) Tensor of relaxed one hot representations of the ancestry decisions.
resampling_dist: A Categorical distribution over ancestors parameterized by the
  same resampling logits (not the relaxed distribution actually sampled from).
"""
assert blend_type in ["log", "linear"], "Blend type must be 'log' or 'linear'."
log_weights = tf.convert_to_tensor(log_weights)
states = [tf.convert_to_tensor(state) for state in states]
state_dim = states[0].get_shape().as_list()[-1]
# weights are num_samples by batch_size, so we transpose to get a
# set of batch_size distributions over [0,num_samples).
resampling_parameters = tf.transpose(log_weights, perm=[1, 0])
resampling_dist = tf.contrib.distributions.RelaxedOneHotCategorical(
temperature,
logits=resampling_parameters)
# sample num_samples samples from the distribution, resulting in a
# [num_samples, batch_size, num_samples] Tensor that represents a set of
# [num_samples, batch_size] blending weights. The dimensions represent
# [sample index, batch index, blending weight index]
ancestors = resampling_dist.sample(sample_shape=num_samples)
if straight_through:
# Forward pass discrete choices, backwards pass soft choices
hard_ancestor_indices = tf.argmax(ancestors, axis=-1)
hard_ancestors = tf.one_hot(hard_ancestor_indices, num_samples,
dtype=ancestors.dtype)
ancestors = tf.stop_gradient(hard_ancestors - ancestors) + ancestors
log_probs = resampling_dist.log_prob(ancestors)
if log_r_x is not None and blend_type == "log":
log_r_x = tf.reshape(log_r_x, [num_samples, batch_size])
log_r_x = log_blend(log_r_x, ancestors)
log_r_x = tf.reshape(log_r_x, [num_samples*batch_size])
elif log_r_x is not None and blend_type == "linear":
# If blend type is linear just add log_r to the states that will be blended
# linearly.
states.append(log_r_x)
# transpose the 'indices' to be [batch_index, blending weight index, sample index]
ancestor_inds = tf.transpose(ancestors, perm=[1, 2, 0])
resampled_states = []
for state in states:
# state is currently [num_samples * batch_size, state_dim] so we reshape
# to [num_samples, batch_size, state_dim] and then transpose to
# [batch_size, state_size, num_samples]
state = tf.transpose(tf.reshape(state, [num_samples, batch_size, -1]), perm=[1, 2, 0])
# state is now (batch_size, state_size, num_samples)
# and ancestor is (batch index, blending weight index, sample index)
# multiplying these gives a matrix of size [batch_size, state_size, num_samples]
next_state = tf.matmul(state, ancestor_inds)
# transpose the state to be [num_samples, batch_size, state_size]
# and then reshape it to match the state format.
next_state = tf.reshape(tf.transpose(next_state, perm=[2,0,1]), [num_samples*batch_size, state_dim])
resampled_states.append(next_state)
new_dist = tf.contrib.distributions.Categorical(
logits=resampling_parameters)
if log_r_x is not None and blend_type == "linear":
# If blend type is linear pop off log_r that we added to the states.
log_r_x = tf.squeeze(resampled_states[-1])
resampled_states = resampled_states[:-1]
return resampled_states, log_probs, log_r_x, resampling_parameters, ancestors, new_dist
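# Hedged usage sketch (added; not part of the original module): relaxed
# resampling produces soft blends of particles instead of hard copies, so
# gradients can flow through the resampling step. The helper name and toy
# shapes are illustrative assumptions only.
def _relaxed_resampling_shape_example(n=4, b=2, d=3, temperature=0.5):
  log_w = tf.zeros([n, b])
  state = tf.random_normal([n * b, d])
  resampled, log_probs, _, _, ancestors, _ = relaxed_resampling(
      log_w, [state], n, b, temperature=temperature)
  # resampled[0]: [n*b, d] soft particle blends; ancestors: [n, b, n] relaxed
  # one-hot blending weights; log_probs: [n, b].
  return resampled, log_probs, ancestors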
def fivo(model,
observation,
num_timesteps,
resampling_schedule,
num_samples=1,
use_resampling_grads=True,
resampling_type="multinomial",
resampling_temperature=0.5,
aux=True,
summarize=False):
"""Compute the FIVO evidence lower bound.
Args:
model: A callable that computes one timestep of the model.
observation: A shape [batch_size*num_samples, state_size] Tensor
containing z_n, the observation for each sequence in the batch.
num_timesteps: The number of timesteps in each sequence, an integer.
resampling_schedule: A list of booleans of length num_timesteps that contains
  True at position t if resampling should occur on timestep t.
num_samples: The number of samples to use to compute the IWAE bound.
use_resampling_grads: Whether or not to include the resampling gradients
  in the loss.
resampling_type: The type of resampling, one of "multinomial", "stratified",
  "systematic", "relaxed-logblend", "relaxed-linearblend",
  "relaxed-stateblend", or "relaxed-stateblend-st".
resampling_temperature: A positive temperature, only used for relaxed
  resampling.
aux: If true, compute the FIVO-AUX bound.
summarize: If True, add TensorBoard summaries of the resampling learning
  signal.
Returns:
log_p_hat: The estimator of the lower bound on the log marginal (equivalent
  to the IWAE bound when no resampling occurs).
losses: A list of Loss tuples that you can perform gradient descent on to
  optimize the bound.
maintain_ema_op: An op to update the baseline ema used for the resampling
  gradients.
states: The sequence of states sampled.
log_weights_all: The accumulated log weights at each timestep, before
  resampling.
"""
# Initialization
num_instances = tf.cast(tf.shape(observation)[0], tf.int32)
batch_size = tf.cast(num_instances / num_samples, tf.int32)
states = [model.zero_state(num_instances)]
prev_state = states[0]
log_weight_acc = tf.zeros(shape=[num_samples, batch_size], dtype=observation.dtype)
prev_log_r_zt = tf.zeros([num_instances], dtype=observation.dtype)
log_weights = []
log_weights_all = []
log_p_hats = []
resampling_log_probs = []
for t in xrange(num_timesteps):
# run the model for one timestep
(zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_zt) = model(
prev_state, observation, t)
# update accumulators
states.append(zt)
log_weight = log_p_zt + log_p_x_given_z - log_q_zt
if aux:
if t == num_timesteps - 1:
log_weight -= prev_log_r_zt
else:
log_weight += log_r_zt - prev_log_r_zt
prev_log_r_zt = log_r_zt
log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
log_weights_all.append(log_weight_acc)
if resampling_schedule[t]:
# These objects will be resampled
to_resample = [states[-1]]
if aux and "relaxed" not in resampling_type:
to_resample.append(prev_log_r_zt)
# do the resampling
if resampling_type == "multinomial":
(resampled,
resampling_log_prob,
_, _, _) = multinomial_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size)
elif resampling_type == "stratified":
(resampled,
resampling_log_prob,
_, _, _) = stratified_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size)
elif resampling_type == "systematic":
(resampled,
resampling_log_prob,
_, _, _) = systematic_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size)
elif "relaxed" in resampling_type:
if aux:
if resampling_type == "relaxed-logblend":
(resampled,
resampling_log_prob,
prev_log_r_zt,
_, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature,
log_r_x=prev_log_r_zt,
blend_type="log")
elif resampling_type == "relaxed-linearblend":
(resampled,
resampling_log_prob,
prev_log_r_zt,
_, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature,
log_r_x=prev_log_r_zt,
blend_type="linear")
elif resampling_type == "relaxed-stateblend":
(resampled,
resampling_log_prob,
_, _, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature)
# Calculate prev_log_r_zt from the post-resampling state
prev_r_zt = model.r.r_xn(resampled[0], t)
prev_log_r_zt = tf.reduce_sum(
prev_r_zt.log_prob(observation), axis=[1])
elif resampling_type == "relaxed-stateblend-st":
(resampled,
resampling_log_prob,
_, _, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature,
straight_through=True)
# Calculate prev_log_r_zt from the post-resampling state
prev_r_zt = model.r.r_xn(resampled[0], t)
prev_log_r_zt = tf.reduce_sum(
prev_r_zt.log_prob(observation), axis=[1])
else:
(resampled,
resampling_log_prob,
_, _, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature)
#if summarize:
# resampling_entropy = resampling_dist.entropy()
# resampling_entropy = tf.reduce_mean(resampling_entropy)
# tf.summary.scalar("weight_entropy/%d" % t, resampling_entropy)
resampling_log_probs.append(tf.reduce_sum(resampling_log_prob, axis=0))
prev_state = resampled[0]
if aux and "relaxed" not in resampling_type:
# Squeeze out the extra dim potentially added by resampling.
# prev_log_r_zt should always be [num_instances]
prev_log_r_zt = tf.squeeze(resampled[1])
# Update the log p hat estimate, taking a log sum exp over the sample
# dimension. The appended tensor is [batch_size].
log_p_hats.append(
tf.reduce_logsumexp(log_weight_acc, axis=0) - tf.log(
tf.cast(num_samples, dtype=observation.dtype)))
# reset the weights
log_weights.append(log_weight_acc)
log_weight_acc = tf.zeros_like(log_weight_acc)
else:
prev_state = states[-1]
# Compute the final weight update. If we just resampled this will be zero.
final_update = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
tf.log(tf.cast(num_samples, dtype=observation.dtype)))
# If we ever resampled, then sum up the previous log p hat terms
if len(log_p_hats) > 0:
log_p_hat = tf.reduce_sum(log_p_hats, axis=0) + final_update
else: # otherwise, log_p_hat only comes from the final update
log_p_hat = final_update
if use_resampling_grads and any(resampling_schedule):
# compute the rewards
# cumsum([a, b, c]) => [a, a+b, a+b+c]
# learning signal at timestep t is
# [sum from i=t+1 to T of log_p_hat_i for t=1:T]
# so we will compute (sum from i=1 to T of log_p_hat_i)
# and at timestep t will subtract off (sum from i=1 to t of log_p_hat_i)
# rewards is a [num_resampling_events, batch_size] Tensor
rewards = tf.stop_gradient(
tf.expand_dims(log_p_hat, 0) - tf.cumsum(log_p_hats, axis=0))
batch_avg_rewards = tf.reduce_mean(rewards, axis=1)
# compute ema baseline.
# centered_rewards is [num_resampling_events, batch_size]
baseline_ema = tf.train.ExponentialMovingAverage(decay=0.94)
maintain_baseline_op = baseline_ema.apply([batch_avg_rewards])
baseline = tf.expand_dims(baseline_ema.average(batch_avg_rewards), 1)
centered_rewards = rewards - baseline
if summarize:
summ.summarize_learning_signal(rewards, "rewards")
summ.summarize_learning_signal(centered_rewards, "centered_rewards")
# compute the loss tensor.
resampling_grads = tf.reduce_sum(
tf.stop_gradient(centered_rewards) * resampling_log_probs, axis=0)
losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat)/num_timesteps),
Loss("resampling_grads", -tf.reduce_mean(resampling_grads)/num_timesteps)]
else:
losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat)/num_timesteps)]
maintain_baseline_op = tf.no_op()
log_p_hat /= num_timesteps
# we clip off the initial state before returning.
return log_p_hat, losses, maintain_baseline_op, states[1:], log_weights_all
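# Hedged usage sketch (added; not part of the original module): one plausible
# way to call fivo, resampling on every other timestep. `my_model` and
# `observations` are hypothetical stand-ins that must satisfy the contract
# described in the fivo docstring above.
def _fivo_usage_example(my_model, observations, num_timesteps, num_samples=4):
  schedule = [t % 2 == 1 for t in range(num_timesteps)]
  log_p_hat, losses, maintain_ema_op, _, _ = fivo(
      my_model,
      observations,
      num_timesteps,
      resampling_schedule=schedule,
      num_samples=num_samples,
      resampling_type="multinomial")
  # Each Loss tuple stores its loss tensor in the second position.
  total_loss = tf.add_n([l[1] for l in losses])
  return log_p_hat, total_loss, maintain_ema_op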
def fivo_aux_td(
model,
observation,
num_timesteps,
resampling_schedule,
num_samples=1,
summarize=False):
"""Compute the FIVO_AUX evidence lower bound."""
# Initialization
num_instances = tf.cast(tf.shape(observation)[0], tf.int32)
batch_size = tf.cast(num_instances / num_samples, tf.int32)
states = [model.zero_state(num_instances)]
prev_state = states[0]
log_weight_acc = tf.zeros(shape=[num_samples, batch_size], dtype=observation.dtype)
prev_log_r = tf.zeros([num_instances], dtype=observation.dtype)
# log_rs must be recorded pre-resampling.
log_rs = []
# r_tilde_params must be recorded post-resampling.
r_tilde_params = [model.r_tilde.r_zt(states[0], observation, 0)]
log_r_tildes = []
log_p_xs = []
# Pre-resampling log weights, recorded only on timesteps where resampling occurs.
log_weights = []
# Pre-resampling log weights, recorded at every timestep.
log_weights_all = []
log_p_hats = []
for t in xrange(num_timesteps):
# run the model for one timestep
# zt is state, [num_instances, state_dim]
# log_q_zt, log_p_x_given_z is [num_instances]
# r_tilde_mu, r_tilde_sigma is [num_instances, state_dim]
# p_ztplus1 is a normal distribution on [num_instances, state_dim]
(zt, log_q_zt, log_p_zt, log_p_x_given_z,
r_tilde_mu, r_tilde_sigma_sq, p_ztplus1) = model(prev_state, observation, t)
# Compute the log weight without log r.
log_weight = log_p_zt + log_p_x_given_z - log_q_zt
# Compute log r.
if t == num_timesteps - 1:
log_r = tf.zeros_like(prev_log_r)
else:
p_mu = p_ztplus1.mean()
p_sigma_sq = p_ztplus1.variance()
log_r = (tf.log(r_tilde_sigma_sq) -
tf.log(r_tilde_sigma_sq + p_sigma_sq) -
tf.square(r_tilde_mu - p_mu)/(r_tilde_sigma_sq + p_sigma_sq))
log_r = 0.5*tf.reduce_sum(log_r, axis=-1)
#log_weight += tf.stop_gradient(log_r - prev_log_r)
log_weight += log_r - prev_log_r
log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
# Update accumulators
states.append(zt)
log_weights_all.append(log_weight_acc)
log_p_xs.append(log_p_x_given_z)
log_rs.append(log_r)
# Compute log_r_tilde as [num_instances] Tensor.
prev_r_tilde_mu, prev_r_tilde_sigma_sq = r_tilde_params[-1]
prev_log_r_tilde = -0.5*tf.reduce_sum(
tf.square(zt - prev_r_tilde_mu)/prev_r_tilde_sigma_sq, axis=-1)
#tf.square(tf.stop_gradient(zt) - r_tilde_mu)/r_tilde_sigma_sq, axis=-1)
#tf.square(zt - r_tilde_mu)/r_tilde_sigma_sq, axis=-1)
log_r_tildes.append(prev_log_r_tilde)
# optionally resample
if resampling_schedule[t]:
# These objects will be resampled
if t < num_timesteps - 1:
to_resample = [zt, log_r, r_tilde_mu, r_tilde_sigma_sq]
else:
to_resample = [zt, log_r]
(resampled,
_, _, _, _) = multinomial_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size)
prev_state = resampled[0]
# Squeeze out the extra dim potentially added by resampling.
# prev_log_r_zt and log_r_tilde should always be [num_instances]
prev_log_r = tf.squeeze(resampled[1])
if t < num_timesteps - 1:
r_tilde_params.append((resampled[2], resampled[3]))
# Update the log p hat estimate, taking a log sum exp over the sample
# dimension. The appended tensor is [batch_size].
log_p_hats.append(
tf.reduce_logsumexp(log_weight_acc, axis=0) - tf.log(
tf.cast(num_samples, dtype=observation.dtype)))
# reset the weights
log_weights.append(log_weight_acc)
log_weight_acc = tf.zeros_like(log_weight_acc)
else:
prev_state = zt
prev_log_r = log_r
if t < num_timesteps - 1:
r_tilde_params.append((r_tilde_mu, r_tilde_sigma_sq))
# Compute the final weight update. If we just resampled this will be zero.
final_update = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
tf.log(tf.cast(num_samples, dtype=observation.dtype)))
# If we ever resampled, then sum up the previous log p hat terms
if len(log_p_hats) > 0:
log_p_hat = tf.reduce_sum(log_p_hats, axis=0) + final_update
else: # otherwise, log_p_hat only comes from the final update
log_p_hat = final_update
# Compute the bellman loss.
# Will remove the first timestep as it is not used.
# log p(x_t|z_t) is in row t-1.
log_p_x = tf.reshape(tf.stack(log_p_xs),
[num_timesteps, num_samples, batch_size])
# log r_t is contained in row t-1.
# The last row is zeros because at the final timestep T (num_timesteps) r is 1.
log_r = tf.reshape(tf.stack(log_rs),
[num_timesteps, num_samples, batch_size])
# [num_timesteps, num_instances]. log r_tilde_t is in row t-1.
log_r_tilde = tf.reshape(tf.stack(log_r_tildes),
[num_timesteps, num_samples, batch_size])
log_lambda = tf.reduce_mean(log_r_tilde - log_p_x - log_r, axis=1,
keepdims=True)
bellman_sos = tf.reduce_mean(tf.square(
log_r_tilde - tf.stop_gradient(log_lambda + log_p_x + log_r)), axis=[0, 1])
bellman_loss = tf.reduce_mean(bellman_sos)/num_timesteps
tf.summary.scalar("bellman_loss", bellman_loss)
if len(tf.get_collection("LOG_P_HAT_VARS")) == 0:
log_p_hat_collection = list(set(tf.trainable_variables()) -
set(tf.get_collection("R_TILDE_VARS")))
for v in log_p_hat_collection:
tf.add_to_collection("LOG_P_HAT_VARS", v)
log_p_hat /= num_timesteps
losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat), "LOG_P_HAT_VARS"),
Loss("bellman_loss", bellman_loss, "R_TILDE_VARS")]
return log_p_hat, losses, tf.no_op(), states[1:], log_weights_all
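# Note (added commentary; not part of the original module): fivo_aux_td splits
# training across two variable collections: the "log_p_hat" Loss updates every
# trainable variable outside R_TILDE_VARS (collected into LOG_P_HAT_VARS), while
# the "bellman_loss" Loss updates only the r_tilde network, with the regression
# target log_lambda + log_p_x + log_r held fixed by tf.stop_gradient.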