Unverified Commit e1ae37c4 authored by aquariusjay's avatar aquariusjay Committed by GitHub

Open-source FEELVOS model, which was developed by Paul Voigtlaender during his 2018 summer internship at Google. The work has been accepted to CVPR 2019. (#6274)
parent 5274ec8b
#!/bin/bash
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# This script is used to run local training on DAVIS 2017. Users can also
# modify this script for their own use case. See eval.sh for an example of
# local inference with a pre-trained model.
#
# Note that this script runs local training with a single GPU and a smaller
# crop and batch size, while for the paper we trained our models on 16 GPUs
# with --num_clones=2, --train_batch_size=6, --num_replicas=8,
# --training_number_of_steps=200000, --train_crop_size=465,
# --train_crop_size=465 (a command-line sketch of these settings follows
# below).
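#
# For reference, a paper-scale run would look roughly like the following
# sketch (a hypothetical command line that only restates the flags listed
# above; it assumes a working multi-replica setup and reuses the dataset,
# checkpoint and logging flags from the command at the bottom of this
# script):
#
#   python "${WORK_DIR}"/train.py \
#     --num_clones=2 \
#     --num_replicas=8 \
#     --train_batch_size=6 \
#     --training_number_of_steps=200000 \
#     --train_crop_size=465 \
#     --train_crop_size=465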
#
# Usage:
# # From the tensorflow/models/research/feelvos directory.
# sh ./train.sh
#
#
# Exit immediately if a command exits with a non-zero status.
set -e
# Move one-level up to tensorflow/models/research directory.
cd ..
# Update PYTHONPATH.
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim:`pwd`/feelvos
# Set up the working environment.
CURRENT_DIR=$(pwd)
WORK_DIR="${CURRENT_DIR}/feelvos"
# Set up the working directories.
DATASET_DIR="datasets"
DAVIS_FOLDER="davis17"
DAVIS_DATASET="${WORK_DIR}/${DATASET_DIR}/${DAVIS_FOLDER}/tfrecord"
EXP_FOLDER="exp/train"
TRAIN_LOGDIR="${WORK_DIR}/${DATASET_DIR}/${DAVIS_FOLDER}/${EXP_FOLDER}/train"
mkdir -p ${TRAIN_LOGDIR}
# Go to datasets folder and download and convert the DAVIS 2017 dataset.
DATASET_DIR="datasets"
cd "${WORK_DIR}/${DATASET_DIR}"
sh download_and_convert_davis17.sh
# Go to models folder and download and unpack the COCO pre-trained model.
MODELS_DIR="models"
mkdir -p "${WORK_DIR}/${MODELS_DIR}"
cd "${WORK_DIR}/${MODELS_DIR}"
if [ ! -d "xception_65_coco_pretrained" ]; then
wget http://download.tensorflow.org/models/xception_65_coco_pretrained_2018_10_02.tar.gz
tar -xvf xception_65_coco_pretrained_2018_10_02.tar.gz
rm xception_65_coco_pretrained_2018_10_02.tar.gz
fi
INIT_CKPT="${WORK_DIR}/${MODELS_DIR}/xception_65_coco_pretrained/x65-b2u1s2p-d48-2-3x256-sc-cr300k_init.ckpt"
# Go back to the original directory.
cd "${CURRENT_DIR}"
python "${WORK_DIR}"/train.py \
--dataset=davis_2017 \
--dataset_dir="${DAVIS_DATASET}" \
--train_logdir="${TRAIN_LOGDIR}" \
--tf_initial_checkpoint="${INIT_CKPT}" \
--logtostderr \
--atrous_rates=6 \
--atrous_rates=12 \
--atrous_rates=18 \
--decoder_output_stride=4 \
--model_variant=xception_65 \
--multi_grid=1 \
--multi_grid=1 \
--multi_grid=1 \
--output_stride=16 \
--weight_decay=0.00004 \
--num_clones=1 \
--train_batch_size=1 \
--train_crop_size=300 \
--train_crop_size=300
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for the instance embedding for segmentation."""
import numpy as np
import tensorflow as tf
from deeplab import model
from deeplab.core import preprocess_utils
from feelvos.utils import mask_damaging
slim = tf.contrib.slim
resolve_shape = preprocess_utils.resolve_shape
WRONG_LABEL_PADDING_DISTANCE = 1e20
# With correlation_cost, local matching will be much faster, but we provide
# a slow fallback for convenience.
USE_CORRELATION_COST = False
if USE_CORRELATION_COST:
# pylint: disable=g-import-not-at-top
from correlation_cost.python.ops import correlation_cost_op
def pairwise_distances(x, y):
"""Computes pairwise squared l2 distances between tensors x and y.
Args:
x: Tensor of shape [n, feature_dim].
y: Tensor of shape [m, feature_dim].
Returns:
Float32 distances tensor of shape [n, m].
"""
# d[i,j] = (x[i] - y[j]) * (x[i] - y[j])'
# = sum(x[i]^2, 1) + sum(y[j]^2, 1) - 2 * x[i] * y[j]'
xs = tf.reduce_sum(x * x, axis=1)[:, tf.newaxis]
ys = tf.reduce_sum(y * y, axis=1)[tf.newaxis, :]
d = xs + ys - 2 * tf.matmul(x, y, transpose_b=True)
return d
def pairwise_distances2(x, y):
"""Computes pairwise squared l2 distances between tensors x and y.
Naive implementation with high memory use; useful for testing the more
efficient implementation.
Args:
x: Tensor of shape [n, feature_dim].
y: Tensor of shape [m, feature_dim].
Returns:
distances of shape [n, m].
"""
return tf.reduce_sum(tf.squared_difference(
x[:, tf.newaxis], y[tf.newaxis, :]), axis=-1)
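# A minimal self-check sketch (not part of the original API; the helper name
# below is hypothetical and assumes a TF1-style session): the efficient and
# the naive pairwise distance implementations should agree numerically.
def _example_compare_pairwise_distances():
  x = tf.constant(np.random.rand(4, 3), dtype=tf.float32)
  y = tf.constant(np.random.rand(5, 3), dtype=tf.float32)
  d_fast = pairwise_distances(x, y)    # shape [4, 5]
  d_naive = pairwise_distances2(x, y)  # shape [4, 5]
  with tf.Session() as sess:
    d_fast_val, d_naive_val = sess.run([d_fast, d_naive])
  np.testing.assert_allclose(d_fast_val, d_naive_val, atol=1e-4)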
def cross_correlate(x, y, max_distance=9):
"""Efficiently computes the cross correlation of x and y.
Optimized implementation using correlation_cost.
Note that we do not normalize by the feature dimension.
Args:
x: Float32 tensor of shape [height, width, feature_dim].
y: Float32 tensor of shape [height, width, feature_dim].
max_distance: Integer, the maximum distance in pixel coordinates
per dimension which is considered to be in the search window.
Returns:
Float32 tensor of shape [height, width, (2 * max_distance + 1) ** 2].
"""
with tf.name_scope('cross_correlation'):
corr = correlation_cost_op.correlation_cost(
x[tf.newaxis], y[tf.newaxis], kernel_size=1,
max_displacement=max_distance, stride_1=1, stride_2=1,
pad=max_distance)
corr = tf.squeeze(corr, axis=0)
# This correlation implementation takes the mean over the feature_dim,
# but we want sum here, so multiply by feature_dim.
feature_dim = resolve_shape(x)[-1]
corr *= feature_dim
return corr
def local_pairwise_distances(x, y, max_distance=9):
"""Computes pairwise squared l2 distances using a local search window.
Optimized implementation using correlation_cost.
Args:
x: Float32 tensor of shape [height, width, feature_dim].
y: Float32 tensor of shape [height, width, feature_dim].
max_distance: Integer, the maximum distance in pixel coordinates
per dimension which is considered to be in the search window.
Returns:
Float32 distances tensor of shape
[height, width, (2 * max_distance + 1) ** 2].
"""
with tf.name_scope('local_pairwise_distances'):
# d[i,j] = (x[i] - y[j]) * (x[i] - y[j])'
# = sum(x[i]^2, -1) + sum(y[j]^2, -1) - 2 * x[i] * y[j]'
corr = cross_correlate(x, y, max_distance=max_distance)
xs = tf.reduce_sum(x * x, axis=2)[..., tf.newaxis]
ys = tf.reduce_sum(y * y, axis=2)[..., tf.newaxis]
ones_ys = tf.ones_like(ys)
ys = cross_correlate(ones_ys, ys, max_distance=max_distance)
d = xs + ys - 2 * corr
# Boundary should be set to Inf.
boundary = tf.equal(
cross_correlate(ones_ys, ones_ys, max_distance=max_distance), 0)
d = tf.where(boundary, tf.fill(tf.shape(d), tf.constant(np.float('inf'))),
d)
return d
def local_pairwise_distances2(x, y, max_distance=9):
"""Computes pairwise squared l2 distances using a local search window.
Naive implementation using map_fn.
Used as a slow fallback for when correlation_cost is not available.
Args:
x: Float32 tensor of shape [height, width, feature_dim].
y: Float32 tensor of shape [height, width, feature_dim].
max_distance: Integer, the maximum distance in pixel coordinates
per dimension which is considered to be in the search window.
Returns:
Float32 distances tensor of shape
[height, width, (2 * max_distance + 1) ** 2].
"""
with tf.name_scope('local_pairwise_distances2'):
padding_val = 1e20
padded_y = tf.pad(y, [[max_distance, max_distance],
[max_distance, max_distance], [0, 0]],
constant_values=padding_val)
height, width, _ = resolve_shape(x)
dists = []
for y_start in range(2 * max_distance + 1):
y_end = y_start + height
y_slice = padded_y[y_start:y_end]
for x_start in range(2 * max_distance + 1):
x_end = x_start + width
offset_y = y_slice[:, x_start:x_end]
dist = tf.reduce_sum(tf.squared_difference(x, offset_y), axis=2)
dists.append(dist)
dists = tf.stack(dists, axis=2)
return dists
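# Illustrative shape check (hypothetical helper, assuming a TF1 session):
# for spatial size [h, w] and max_distance d, the slow fallback returns
# (2 * d + 1) ** 2 distance entries per pixel.
def _example_local_pairwise_distances2_shape():
  x = tf.constant(np.random.rand(6, 7, 2), dtype=tf.float32)
  y = tf.constant(np.random.rand(6, 7, 2), dtype=tf.float32)
  d = local_pairwise_distances2(x, y, max_distance=3)
  with tf.Session() as sess:
    d_val = sess.run(d)
  assert d_val.shape == (6, 7, 49)  # 49 == (2 * 3 + 1) ** 2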
def majority_vote(labels):
"""Performs a label majority vote along axis 1.
We assume that the labels are contiguous and start from 0. Non-contiguous
labels will also work, but inefficiently.
Args:
labels: Int tensor of shape [n, k].
Returns:
The majority label along axis 1, as a tensor of shape [n].
"""
max_label = tf.reduce_max(labels)
one_hot = tf.one_hot(labels, depth=max_label + 1)
summed = tf.reduce_sum(one_hot, axis=1)
majority = tf.argmax(summed, axis=1)
return majority
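# Small worked example (hypothetical helper, assuming a TF1 session): for
# each row the most frequent label along axis 1 is returned.
def _example_majority_vote():
  labels = tf.constant([[0, 1, 1, 2],
                        [2, 2, 0, 1]], dtype=tf.int32)
  majority = majority_vote(labels)
  with tf.Session() as sess:
    print(sess.run(majority))  # -> [1 2]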
def assign_labels_by_nearest_neighbors(reference_embeddings, query_embeddings,
reference_labels, k=1):
"""Segments by nearest neighbor query wrt the reference frame.
Args:
reference_embeddings: Tensor of shape [height, width, embedding_dim],
the embedding vectors for the reference frame
query_embeddings: Tensor of shape [n_query_images, height, width,
embedding_dim], the embedding vectors for the query frames
reference_labels: Tensor of shape [height, width, 1], the class labels of
the reference frame
k: Integer, the number of nearest neighbors to use
Returns:
The labels of the nearest neighbors as [n_query_frames, height, width, 1]
tensor
Raises:
ValueError: If k < 1.
"""
if k < 1:
raise ValueError('k must be at least 1')
dists = flattened_pairwise_distances(reference_embeddings, query_embeddings)
if k == 1:
nn_indices = tf.argmin(dists, axis=1)[..., tf.newaxis]
else:
_, nn_indices = tf.nn.top_k(-dists, k, sorted=False)
reference_labels = tf.reshape(reference_labels, [-1])
nn_labels = tf.gather(reference_labels, nn_indices)
if k == 1:
nn_labels = tf.squeeze(nn_labels, axis=1)
else:
nn_labels = majority_vote(nn_labels)
height = tf.shape(reference_embeddings)[0]
width = tf.shape(reference_embeddings)[1]
n_query_frames = query_embeddings.shape[0]
nn_labels = tf.reshape(nn_labels, [n_query_frames, height, width, 1])
return nn_labels
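# Usage sketch (hypothetical helper, assuming a TF1 session): segment two
# query frames by a 1-nearest-neighbor lookup against a labelled reference
# frame; the shapes below are made up for illustration.
def _example_assign_labels_by_nearest_neighbors():
  reference_embeddings = tf.constant(
      np.random.rand(8, 8, 16), dtype=tf.float32)
  query_embeddings = tf.constant(
      np.random.rand(2, 8, 8, 16), dtype=tf.float32)
  reference_labels = tf.constant(
      np.random.randint(0, 3, size=(8, 8, 1)), dtype=tf.int32)
  nn_labels = assign_labels_by_nearest_neighbors(
      reference_embeddings, query_embeddings, reference_labels, k=1)
  with tf.Session() as sess:
    print(sess.run(nn_labels).shape)  # -> (2, 8, 8, 1)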
def flattened_pairwise_distances(reference_embeddings, query_embeddings):
"""Calculates flattened tensor of pairwise distances between ref and query.
Args:
reference_embeddings: Tensor of shape [..., embedding_dim],
the embedding vectors for the reference frame
query_embeddings: Tensor of shape [n_query_images, height, width,
embedding_dim], the embedding vectors for the query frames.
Returns:
A distance tensor of shape [reference_embeddings.size / embedding_dim,
query_embeddings.size / embedding_dim]
"""
embedding_dim = resolve_shape(query_embeddings)[-1]
reference_embeddings = tf.reshape(reference_embeddings, [-1, embedding_dim])
first_dim = -1
query_embeddings = tf.reshape(query_embeddings, [first_dim, embedding_dim])
dists = pairwise_distances(query_embeddings, reference_embeddings)
return dists
def nearest_neighbor_features_per_object(
reference_embeddings, query_embeddings, reference_labels,
max_neighbors_per_object, k_nearest_neighbors, gt_ids=None, n_chunks=100):
"""Calculates the distance to the nearest neighbor per object.
For every pixel of query_embeddings calculate the distance to the
nearest neighbor in the (possibly subsampled) reference_embeddings per object.
Args:
reference_embeddings: Tensor of shape [height, width, embedding_dim],
the embedding vectors for the reference frame.
query_embeddings: Tensor of shape [n_query_images, height, width,
embedding_dim], the embedding vectors for the query frames.
reference_labels: Tensor of shape [height, width, 1], the class labels of
the reference frame.
max_neighbors_per_object: Integer, the maximum number of candidates
for the nearest neighbor query per object after subsampling,
or 0 for no subsampling.
k_nearest_neighbors: Integer, the number of nearest neighbors to use.
gt_ids: Int tensor of shape [n_objs] of the sorted unique ground truth
ids in the first frame. If None, it will be derived from
reference_labels.
n_chunks: Integer, the number of chunks to use to save memory
(set to 1 for no chunking).
Returns:
nn_features: A float32 tensor of nearest neighbor features of shape
[n_query_images, height, width, n_objects, feature_dim].
gt_ids: An int32 tensor of the unique sorted object ids present
in the reference labels.
"""
with tf.name_scope('nn_features_per_object'):
reference_labels_flat = tf.reshape(reference_labels, [-1])
if gt_ids is None:
ref_obj_ids, _ = tf.unique(reference_labels_flat)
ref_obj_ids = tf.contrib.framework.sort(ref_obj_ids)
gt_ids = ref_obj_ids
embedding_dim = resolve_shape(reference_embeddings)[-1]
reference_embeddings_flat = tf.reshape(reference_embeddings,
[-1, embedding_dim])
reference_embeddings_flat, reference_labels_flat = (
subsample_reference_embeddings_and_labels(reference_embeddings_flat,
reference_labels_flat,
gt_ids,
max_neighbors_per_object))
shape = resolve_shape(query_embeddings)
query_embeddings_flat = tf.reshape(query_embeddings, [-1, embedding_dim])
nn_features = _nearest_neighbor_features_per_object_in_chunks(
reference_embeddings_flat, query_embeddings_flat, reference_labels_flat,
gt_ids, k_nearest_neighbors, n_chunks)
nn_features_dim = resolve_shape(nn_features)[-1]
nn_features_reshaped = tf.reshape(nn_features,
tf.stack(shape[:3] + [tf.size(gt_ids),
nn_features_dim]))
return nn_features_reshaped, gt_ids
def _nearest_neighbor_features_per_object_in_chunks(
reference_embeddings_flat, query_embeddings_flat, reference_labels_flat,
ref_obj_ids, k_nearest_neighbors, n_chunks):
"""Calculates the nearest neighbor features per object in chunks to save mem.
Uses chunking to bound the memory use.
Args:
reference_embeddings_flat: Tensor of shape [n, embedding_dim],
the embedding vectors for the reference frame.
query_embeddings_flat: Tensor of shape [m, embedding_dim], the embedding
vectors for the query frames.
reference_labels_flat: Tensor of shape [n], the class labels of the
reference frame.
ref_obj_ids: int tensor of unique object ids in the reference labels.
k_nearest_neighbors: Integer, the number of nearest neighbors to use.
n_chunks: Integer, the number of chunks to use to save memory
(set to 1 for no chunking).
Returns:
nn_features: A float32 tensor of nearest neighbor features of shape
[m, n_objects, feature_dim].
"""
chunk_size = tf.cast(tf.ceil(tf.cast(tf.shape(query_embeddings_flat)[0],
tf.float32) / n_chunks), tf.int32)
wrong_label_mask = tf.not_equal(reference_labels_flat,
ref_obj_ids[:, tf.newaxis])
all_features = []
for n in range(n_chunks):
if n_chunks == 1:
query_embeddings_flat_chunk = query_embeddings_flat
else:
chunk_start = n * chunk_size
chunk_end = (n + 1) * chunk_size
query_embeddings_flat_chunk = query_embeddings_flat[chunk_start:chunk_end]
# Use control dependencies to make sure that the chunks are not processed
# in parallel which would prevent any peak memory savings.
with tf.control_dependencies(all_features):
features = _nn_features_per_object_for_chunk(
reference_embeddings_flat, query_embeddings_flat_chunk,
wrong_label_mask, k_nearest_neighbors
)
all_features.append(features)
if n_chunks == 1:
nn_features = all_features[0]
else:
nn_features = tf.concat(all_features, axis=0)
return nn_features
def _nn_features_per_object_for_chunk(
reference_embeddings, query_embeddings, wrong_label_mask,
k_nearest_neighbors):
"""Extracts features for each object using nearest neighbor attention.
Args:
reference_embeddings: Tensor of shape [n_chunk, embedding_dim],
the embedding vectors for the reference frame.
query_embeddings: Tensor of shape [m_chunk, embedding_dim], the embedding
vectors for the query frames.
wrong_label_mask: Boolean tensor of shape [n_objects, n], True where a
reference embedding does not belong to the corresponding object.
k_nearest_neighbors: Integer, the number of nearest neighbors to use.
Returns:
nn_features: A float32 tensor of nearest neighbor features of shape
[m_chunk, n_objects, feature_dim].
"""
reference_embeddings_key = reference_embeddings
query_embeddings_key = query_embeddings
dists = flattened_pairwise_distances(reference_embeddings_key,
query_embeddings_key)
dists = (dists[:, tf.newaxis, :] +
tf.cast(wrong_label_mask[tf.newaxis, :, :], tf.float32) *
WRONG_LABEL_PADDING_DISTANCE)
if k_nearest_neighbors == 1:
features = tf.reduce_min(dists, axis=2, keepdims=True)
else:
# Find the k closest distances and combine them below.
dists, _ = tf.nn.top_k(-dists, k=k_nearest_neighbors)
dists = -dists
# If not enough real neighbors were found, pad with the farthest real
# neighbor.
valid_mask = tf.less(dists, WRONG_LABEL_PADDING_DISTANCE)
masked_dists = dists * tf.cast(valid_mask, tf.float32)
pad_dist = tf.tile(tf.reduce_max(masked_dists, axis=2)[..., tf.newaxis],
multiples=[1, 1, k_nearest_neighbors])
dists = tf.where(valid_mask, dists, pad_dist)
# take mean of distances
features = tf.reduce_mean(dists, axis=2, keepdims=True)
return features
def create_embedding_segmentation_features(features, feature_dimension,
n_layers, kernel_size, reuse,
atrous_rates=None):
"""Extracts features which can be used to estimate the final segmentation.
Args:
features: Input features of shape [batch, height, width, channels].
feature_dimension: Integer, the dimensionality used in the segmentation
head layers.
n_layers: Integer, the number of layers in the segmentation head.
kernel_size: Integer, the kernel size used in the segmentation head.
reuse: reuse mode for the variable_scope.
atrous_rates: List of integers of length n_layers, the atrous rates to use.
Returns:
Features to be used to estimate the segmentation labels of shape
[batch, height, width, embedding_seg_feat_dim].
"""
if atrous_rates is None or not atrous_rates:
atrous_rates = [1 for _ in range(n_layers)]
assert len(atrous_rates) == n_layers
with tf.variable_scope('embedding_seg', reuse=reuse):
for n in range(n_layers):
features = model.split_separable_conv2d(
features, feature_dimension, kernel_size=kernel_size,
rate=atrous_rates[n], scope='split_separable_conv2d_{}'.format(n))
return features
def add_image_summaries(images, nn_features, logits, batch_size,
prev_frame_nn_features=None):
"""Adds image summaries of input images, attention features and logits.
Args:
images: Image tensor of shape [batch, height, width, channels].
nn_features: Nearest neighbor attention features of shape
[batch_size, height, width, n_objects, 1].
logits: Float32 tensor of logits.
batch_size: Integer, the number of videos per clone per mini-batch.
prev_frame_nn_features: Nearest neighbor attention features wrt. the
last frame of shape [batch_size, height, width, n_objects, 1].
Can be None.
"""
# Separate reference and query images.
reshaped_images = tf.reshape(images, tf.stack(
[batch_size, -1] + resolve_shape(images)[1:]))
reference_images = reshaped_images[:, 0]
query_images = reshaped_images[:, 1:]
query_images_reshaped = tf.reshape(query_images, tf.stack(
[-1] + resolve_shape(images)[1:]))
tf.summary.image('ref_images', reference_images, max_outputs=batch_size)
tf.summary.image('query_images', query_images_reshaped, max_outputs=10)
predictions = tf.cast(
tf.argmax(logits, axis=-1), tf.uint8)[..., tf.newaxis]
# Scale up so that we can actually see something.
tf.summary.image('predictions', predictions * 32, max_outputs=10)
# We currently only show the first dimension of the features for background
# and the first foreground object.
tf.summary.image('nn_fg_features', nn_features[..., 0:1, 0],
max_outputs=batch_size)
if prev_frame_nn_features is not None:
tf.summary.image('nn_fg_features_prev', prev_frame_nn_features[..., 0:1, 0],
max_outputs=batch_size)
tf.summary.image('nn_bg_features', nn_features[..., 1:2, 0],
max_outputs=batch_size)
if prev_frame_nn_features is not None:
tf.summary.image('nn_bg_features_prev',
prev_frame_nn_features[..., 1:2, 0],
max_outputs=batch_size)
def get_embeddings(images, model_options, embedding_dimension):
"""Extracts embedding vectors for images. Should only be used for inference.
Args:
images: A tensor of shape [batch, height, width, channels].
model_options: A ModelOptions instance to configure models.
embedding_dimension: Integer, the dimension of the embedding.
Returns:
embeddings: A tensor of shape [batch, height, width, embedding_dimension].
"""
features, end_points = model.extract_features(
images,
model_options,
is_training=False)
if model_options.decoder_output_stride is not None:
if model_options.crop_size is None:
height = tf.shape(images)[1]
width = tf.shape(images)[2]
else:
height, width = model_options.crop_size
decoder_height = model.scale_dimension(
height, 1.0 / model_options.decoder_output_stride)
decoder_width = model.scale_dimension(
width, 1.0 / model_options.decoder_output_stride)
features = model.refine_by_decoder(
features,
end_points,
decoder_height=decoder_height,
decoder_width=decoder_width,
decoder_use_separable_conv=model_options.decoder_use_separable_conv,
model_variant=model_options.model_variant,
is_training=False)
with tf.variable_scope('embedding'):
embeddings = split_separable_conv2d_with_identity_initializer(
features, embedding_dimension, scope='split_separable_conv2d')
return embeddings
def get_logits_with_matching(images,
model_options,
weight_decay=0.0001,
reuse=None,
is_training=False,
fine_tune_batch_norm=False,
reference_labels=None,
batch_size=None,
num_frames_per_video=None,
embedding_dimension=None,
max_neighbors_per_object=0,
k_nearest_neighbors=1,
use_softmax_feedback=True,
initial_softmax_feedback=None,
embedding_seg_feature_dimension=256,
embedding_seg_n_layers=4,
embedding_seg_kernel_size=7,
embedding_seg_atrous_rates=None,
normalize_nearest_neighbor_distances=True,
also_attend_to_previous_frame=True,
damage_initial_previous_frame_mask=False,
use_local_previous_frame_attention=True,
previous_frame_attention_window_size=15,
use_first_frame_matching=True,
also_return_embeddings=False,
ref_embeddings=None):
"""Gets the logits by atrous/image spatial pyramid pooling using attention.
Args:
images: A tensor of size [batch, height, width, channels].
model_options: A ModelOptions instance to configure models.
weight_decay: The weight decay for model variables.
reuse: Reuse the model variables or not.
is_training: Is training or not.
fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
reference_labels: The segmentation labels of the reference frame on which
attention is applied.
batch_size: Integer, the number of videos in a batch.
num_frames_per_video: Integer, the number of frames per video.
embedding_dimension: Integer, the dimension of the embedding.
max_neighbors_per_object: Integer, the maximum number of candidates
for the nearest neighbor query per object after subsampling.
Can be 0 for no subsampling.
k_nearest_neighbors: Integer, the number of nearest neighbors to use.
use_softmax_feedback: Boolean, whether to give the softmax predictions of
the last frame as additional input to the segmentation head.
initial_softmax_feedback: List of Float32 tensors, or None. Can be used to
initialize the softmax predictions used for the feedback loop.
Only has an effect if use_softmax_feedback is True.
embedding_seg_feature_dimension: Integer, the dimensionality used in the
segmentation head layers.
embedding_seg_n_layers: Integer, the number of layers in the segmentation
head.
embedding_seg_kernel_size: Integer, the kernel size used in the
segmentation head.
embedding_seg_atrous_rates: List of integers of length
embedding_seg_n_layers, the atrous rates to use for the segmentation head.
normalize_nearest_neighbor_distances: Boolean, whether to normalize the
nearest neighbor distances to [0,1] using sigmoid, scale and shift.
also_attend_to_previous_frame: Boolean, whether to also use nearest
neighbor attention with respect to the previous frame.
damage_initial_previous_frame_mask: Boolean, whether to artificially damage
the initial previous frame mask. Only has an effect if
also_attend_to_previous_frame is True.
use_local_previous_frame_attention: Boolean, whether to restrict the
previous frame attention to a local search window.
Only has an effect if also_attend_to_previous_frame is True.
previous_frame_attention_window_size: Integer, the window size used for
local previous frame attention, if use_local_previous_frame_attention
is True.
use_first_frame_matching: Boolean, whether to extract features by matching
to the reference frame. This should always be true except for ablation
experiments.
also_return_embeddings: Boolean, whether to return the embeddings as well.
ref_embeddings: Tuple of
(first_frame_embeddings, previous_frame_embeddings),
each of shape [batch, height, width, embedding_dimension], or None.
Returns:
outputs_to_logits: A map from output_type to logits.
If also_return_embeddings is True, it will also return an embeddings
tensor of shape [batch, height, width, embedding_dimension].
"""
features, end_points = model.extract_features(
images,
model_options,
weight_decay=weight_decay,
reuse=reuse,
is_training=is_training,
fine_tune_batch_norm=fine_tune_batch_norm)
if model_options.decoder_output_stride is not None:
if model_options.crop_size is None:
height = tf.shape(images)[1]
width = tf.shape(images)[2]
else:
height, width = model_options.crop_size
decoder_height = model.scale_dimension(
height, 1.0 / model_options.decoder_output_stride)
decoder_width = model.scale_dimension(
width, 1.0 / model_options.decoder_output_stride)
features = model.refine_by_decoder(
features,
end_points,
decoder_height=decoder_height,
decoder_width=decoder_width,
decoder_use_separable_conv=model_options.decoder_use_separable_conv,
model_variant=model_options.model_variant,
weight_decay=weight_decay,
reuse=reuse,
is_training=is_training,
fine_tune_batch_norm=fine_tune_batch_norm)
with tf.variable_scope('embedding', reuse=reuse):
embeddings = split_separable_conv2d_with_identity_initializer(
features, embedding_dimension, scope='split_separable_conv2d')
embeddings = tf.identity(embeddings, name='embeddings')
scaled_reference_labels = tf.image.resize_nearest_neighbor(
reference_labels,
resolve_shape(embeddings, 4)[1:3],
align_corners=True)
h, w = decoder_height, decoder_width
if num_frames_per_video is None:
num_frames_per_video = tf.size(embeddings) // (
batch_size * h * w * embedding_dimension)
new_labels_shape = tf.stack([batch_size, -1, h, w, 1])
reshaped_reference_labels = tf.reshape(scaled_reference_labels,
new_labels_shape)
new_embeddings_shape = tf.stack([batch_size,
num_frames_per_video, h, w,
embedding_dimension])
reshaped_embeddings = tf.reshape(embeddings, new_embeddings_shape)
all_nn_features = []
all_ref_obj_ids = []
# To keep things simple, we do all of this separately per sequence for now.
for n in range(batch_size):
embedding = reshaped_embeddings[n]
if ref_embeddings is None:
n_chunks = 100
reference_embedding = embedding[0]
if also_attend_to_previous_frame or use_softmax_feedback:
queries_embedding = embedding[2:]
else:
queries_embedding = embedding[1:]
else:
if USE_CORRELATION_COST:
n_chunks = 20
else:
n_chunks = 500
reference_embedding = ref_embeddings[0][n]
queries_embedding = embedding
reference_labels = reshaped_reference_labels[n][0]
nn_features_n, ref_obj_ids = nearest_neighbor_features_per_object(
reference_embedding, queries_embedding, reference_labels,
max_neighbors_per_object, k_nearest_neighbors, n_chunks=n_chunks)
if normalize_nearest_neighbor_distances:
nn_features_n = (tf.nn.sigmoid(nn_features_n) - 0.5) * 2
all_nn_features.append(nn_features_n)
all_ref_obj_ids.append(ref_obj_ids)
feat_dim = resolve_shape(features)[-1]
features = tf.reshape(features, tf.stack(
[batch_size, num_frames_per_video, h, w, feat_dim]))
if ref_embeddings is None:
# Strip the features for the reference frame.
if also_attend_to_previous_frame or use_softmax_feedback:
features = features[:, 2:]
else:
features = features[:, 1:]
# To keep things simple, we do all of this separately per sequence for now.
outputs_to_logits = {output: [] for
output in model_options.outputs_to_num_classes}
for n in range(batch_size):
features_n = features[n]
nn_features_n = all_nn_features[n]
nn_features_n_tr = tf.transpose(nn_features_n, [3, 0, 1, 2, 4])
n_objs = tf.shape(nn_features_n_tr)[0]
# Repeat features for every object.
features_n_tiled = tf.tile(features_n[tf.newaxis],
multiples=[n_objs, 1, 1, 1, 1])
prev_frame_labels = None
if also_attend_to_previous_frame:
prev_frame_labels = reshaped_reference_labels[n, 1]
if is_training and damage_initial_previous_frame_mask:
# Damage the previous frame masks.
prev_frame_labels = mask_damaging.damage_masks(prev_frame_labels,
dilate=False)
tf.summary.image('prev_frame_labels',
tf.cast(prev_frame_labels[tf.newaxis],
tf.uint8) * 32)
initial_softmax_feedback_n = create_initial_softmax_from_labels(
prev_frame_labels, reshaped_reference_labels[n][0],
decoder_output_stride=None, reduce_labels=True)
elif initial_softmax_feedback is not None:
initial_softmax_feedback_n = initial_softmax_feedback[n]
else:
initial_softmax_feedback_n = None
if initial_softmax_feedback_n is None:
last_softmax = tf.zeros((n_objs, h, w, 1), dtype=tf.float32)
else:
last_softmax = tf.transpose(initial_softmax_feedback_n, [2, 0, 1])[
..., tf.newaxis]
assert len(model_options.outputs_to_num_classes) == 1
output = model_options.outputs_to_num_classes.keys()[0]
logits = []
n_ref_frames = 1
prev_frame_nn_features_n = None
if also_attend_to_previous_frame or use_softmax_feedback:
n_ref_frames += 1
if ref_embeddings is not None:
n_ref_frames = 0
for t in range(num_frames_per_video - n_ref_frames):
to_concat = [features_n_tiled[:, t]]
if use_first_frame_matching:
to_concat.append(nn_features_n_tr[:, t])
if use_softmax_feedback:
to_concat.append(last_softmax)
if also_attend_to_previous_frame:
assert normalize_nearest_neighbor_distances, (
'previous frame attention currently only works when normalized '
'distances are used')
embedding = reshaped_embeddings[n]
if ref_embeddings is None:
last_frame_embedding = embedding[t + 1]
query_embeddings = embedding[t + 2, tf.newaxis]
else:
last_frame_embedding = ref_embeddings[1][0]
query_embeddings = embedding
if use_local_previous_frame_attention:
assert query_embeddings.shape[0] == 1
prev_frame_nn_features_n = (
local_previous_frame_nearest_neighbor_features_per_object(
last_frame_embedding,
query_embeddings[0],
prev_frame_labels,
all_ref_obj_ids[n],
max_distance=previous_frame_attention_window_size)
)
else:
prev_frame_nn_features_n, _ = (
nearest_neighbor_features_per_object(
last_frame_embedding, query_embeddings, prev_frame_labels,
max_neighbors_per_object, k_nearest_neighbors,
gt_ids=all_ref_obj_ids[n]))
prev_frame_nn_features_n = (tf.nn.sigmoid(
prev_frame_nn_features_n) - 0.5) * 2
prev_frame_nn_features_n_sq = tf.squeeze(prev_frame_nn_features_n,
axis=0)
prev_frame_nn_features_n_tr = tf.transpose(
prev_frame_nn_features_n_sq, [2, 0, 1, 3])
to_concat.append(prev_frame_nn_features_n_tr)
features_n_concat_t = tf.concat(to_concat, axis=-1)
embedding_seg_features_n_t = (
create_embedding_segmentation_features(
features_n_concat_t, embedding_seg_feature_dimension,
embedding_seg_n_layers, embedding_seg_kernel_size,
reuse or n > 0, atrous_rates=embedding_seg_atrous_rates))
logits_t = model.get_branch_logits(
embedding_seg_features_n_t,
1,
model_options.atrous_rates,
aspp_with_batch_norm=model_options.aspp_with_batch_norm,
kernel_size=model_options.logits_kernel_size,
weight_decay=weight_decay,
reuse=reuse or n > 0 or t > 0,
scope_suffix=output
)
logits.append(logits_t)
prev_frame_labels = tf.transpose(tf.argmax(logits_t, axis=0),
[2, 0, 1])
last_softmax = tf.nn.softmax(logits_t, axis=0)
logits = tf.stack(logits, axis=1)
logits_shape = tf.stack(
[n_objs, num_frames_per_video - n_ref_frames] +
resolve_shape(logits)[2:-1])
logits_reshaped = tf.reshape(logits, logits_shape)
logits_transposed = tf.transpose(logits_reshaped, [1, 2, 3, 0])
outputs_to_logits[output].append(logits_transposed)
add_image_summaries(
images[n * num_frames_per_video: (n+1) * num_frames_per_video],
nn_features_n,
logits_transposed,
batch_size=1,
prev_frame_nn_features=prev_frame_nn_features_n)
if also_return_embeddings:
return outputs_to_logits, embeddings
else:
return outputs_to_logits
def subsample_reference_embeddings_and_labels(
reference_embeddings_flat, reference_labels_flat, ref_obj_ids,
max_neighbors_per_object):
"""Subsamples the reference embedding vectors and labels.
After subsampling, at most max_neighbors_per_object items will remain per
class.
Args:
reference_embeddings_flat: Tensor of shape [n, embedding_dim],
the embedding vectors for the reference frame.
reference_labels_flat: Tensor of shape [n], the class labels of the
reference frame.
ref_obj_ids: An int32 tensor of the unique object ids present
in the reference labels.
max_neighbors_per_object: Integer, the maximum number of candidates
for the nearest neighbor query per object after subsampling,
or 0 for no subsampling.
Returns:
reference_embeddings_flat: Tensor of shape [n_sub, embedding_dim],
the subsampled embedding vectors for the reference frame.
reference_labels_flat: Tensor of shape [n_sub], the subsampled class
labels of the reference frame.
"""
if max_neighbors_per_object == 0:
return reference_embeddings_flat, reference_labels_flat
same_label_mask = tf.equal(reference_labels_flat[tf.newaxis, :],
ref_obj_ids[:, tf.newaxis])
max_neighbors_per_object_repeated = tf.tile(
tf.constant(max_neighbors_per_object)[tf.newaxis],
multiples=[tf.size(ref_obj_ids)])
# map_fn on GPU sometimes caused trouble, so we place it on the CPU for now.
with tf.device('cpu:0'):
subsampled_indices = tf.map_fn(_create_subsampling_mask,
(same_label_mask,
max_neighbors_per_object_repeated),
dtype=tf.int64,
name='subsample_labels_map_fn',
parallel_iterations=1)
mask = tf.not_equal(subsampled_indices, tf.constant(-1, dtype=tf.int64))
masked_indices = tf.boolean_mask(subsampled_indices, mask)
reference_embeddings_flat = tf.gather(reference_embeddings_flat,
masked_indices)
reference_labels_flat = tf.gather(reference_labels_flat, masked_indices)
return reference_embeddings_flat, reference_labels_flat
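# Usage sketch (hypothetical helper, assuming a TF1 session): keep at most
# four reference pixels per object id before running the nearest neighbor
# query; the sizes below are made up for illustration.
def _example_subsample_reference_embeddings_and_labels():
  reference_embeddings_flat = tf.constant(
      np.random.rand(100, 8), dtype=tf.float32)
  reference_labels_flat = tf.constant(
      np.random.randint(0, 3, size=(100,)), dtype=tf.int32)
  ref_obj_ids = tf.constant([0, 1, 2], dtype=tf.int32)
  sub_embeddings, sub_labels = subsample_reference_embeddings_and_labels(
      reference_embeddings_flat, reference_labels_flat, ref_obj_ids,
      max_neighbors_per_object=4)
  with tf.Session() as sess:
    embeddings_val, labels_val = sess.run([sub_embeddings, sub_labels])
  # At most 4 rows remain per object id, so at most 12 rows in total here.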
def _create_subsampling_mask(args):
"""Creates boolean mask which can be used to subsample the labels.
Args:
args: tuple of (label_mask, max_neighbors_per_object), where label_mask
is the mask to be subsampled and max_neighbors_per_object is a int scalar,
the maximum number of neighbors to be retained after subsampling.
Returns:
The boolean mask for subsampling the labels.
"""
label_mask, max_neighbors_per_object = args
indices = tf.squeeze(tf.where(label_mask), axis=1)
shuffled_indices = tf.random_shuffle(indices)
subsampled_indices = shuffled_indices[:max_neighbors_per_object]
n_pad = max_neighbors_per_object - tf.size(subsampled_indices)
padded_label = -1
padding = tf.fill((n_pad,), tf.constant(padded_label, dtype=tf.int64))
padded = tf.concat([subsampled_indices, padding], axis=0)
return padded
def conv2d_identity_initializer(scale=1.0, mean=0, stddev=3e-2):
"""Creates an identity initializer for TensorFlow conv2d.
We add a small amount of normal noise to the initialization matrix.
Code copied from lcchen@.
Args:
scale: The scale coefficient for the identity weight matrix.
mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
truncated normal distribution.
stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
of the truncated normal distribution.
Returns:
An identity initializer function for TensorFlow conv2d.
"""
def _initializer(shape, dtype=tf.float32, partition_info=None):
"""Returns the identity matrix scaled by `scale`.
Args:
shape: A tuple of int32 numbers indicating the shape of the initializing
matrix.
dtype: The data type of the initializing matrix.
partition_info: (Optional) variable_scope._PartitionInfo object holding
additional information about how the variable is partitioned. This input
is not used in our case, but is required by TensorFlow.
Returns:
An identity matrix scaled by `scale`, plus a small amount of truncated
normal noise.
Raises:
ValueError: If len(shape) != 4, or shape[0] != shape[1], or shape[0] is
not odd, or shape[1] is not odd.
"""
del partition_info
if len(shape) != 4:
raise ValueError('Expect shape length to be 4.')
if shape[0] != shape[1]:
raise ValueError('Expect shape[0] = shape[1].')
if shape[0] % 2 != 1:
raise ValueError('Expect shape[0] to be odd value.')
if shape[1] % 2 != 1:
raise ValueError('Expect shape[1] to be odd value.')
weights = np.zeros(shape, dtype=np.float32)
center_y = shape[0] / 2
center_x = shape[1] / 2
min_channel = min(shape[2], shape[3])
for i in range(min_channel):
weights[center_y, center_x, i, i] = scale
return tf.constant(weights, dtype=dtype) + tf.truncated_normal(
shape, mean=mean, stddev=stddev, dtype=dtype)
return _initializer
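# Usage sketch (hypothetical scope and shapes; relies on the slim import at
# the top of this file): the initializer can be plugged into any square,
# odd-sized conv2d kernel so that the layer starts close to an identity
# mapping plus a small amount of noise.
def _example_identity_initialized_conv2d():
  inputs = tf.placeholder(tf.float32, shape=[None, 65, 65, 32])
  return slim.conv2d(
      inputs, 32, [3, 3],
      weights_initializer=conv2d_identity_initializer(),
      scope='identity_init_conv2d_example')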
def split_separable_conv2d_with_identity_initializer(
inputs,
filters,
kernel_size=3,
rate=1,
weight_decay=0.00004,
scope=None):
"""Splits a separable conv2d into depthwise and pointwise conv2d.
This operation differs from `tf.layers.separable_conv2d` as this operation
applies activation function between depthwise and pointwise conv2d.
Args:
inputs: Input tensor with shape [batch, height, width, channels].
filters: Number of filters in the 1x1 pointwise convolution.
kernel_size: A list of length 2: [kernel_height, kernel_width] of the
filters. Can be an int if both values are the same.
rate: Atrous convolution rate for the depthwise convolution.
weight_decay: The weight decay to use for regularizing the model.
scope: Optional scope for the operation.
Returns:
Computed features after split separable conv2d.
"""
initializer = conv2d_identity_initializer()
outputs = slim.separable_conv2d(
inputs,
None,
kernel_size=kernel_size,
depth_multiplier=1,
rate=rate,
weights_initializer=initializer,
weights_regularizer=None,
scope=scope + '_depthwise')
return slim.conv2d(
outputs,
filters,
1,
weights_initializer=initializer,
weights_regularizer=slim.l2_regularizer(weight_decay),
scope=scope + '_pointwise')
def create_initial_softmax_from_labels(last_frame_labels, reference_labels,
decoder_output_stride, reduce_labels):
"""Creates initial softmax predictions from last frame labels.
Args:
last_frame_labels: last frame labels of shape [1, height, width, 1].
reference_labels: reference frame labels of shape [1, height, width, 1].
decoder_output_stride: Integer, the stride of the decoder. Can be None, in
this case it's assumed that the last_frame_labels and reference_labels
are already scaled to the decoder output resolution.
reduce_labels: Boolean, whether to reduce the depth of the softmax one_hot
encoding to the actual number of labels present in the reference frame
(otherwise the depth will be the highest label index + 1).
Returns:
init_softmax: the initial softmax predictions.
"""
if decoder_output_stride is None:
labels_output_size = last_frame_labels
reference_labels_output_size = reference_labels
else:
h = tf.shape(last_frame_labels)[1]
w = tf.shape(last_frame_labels)[2]
h_sub = model.scale_dimension(h, 1.0 / decoder_output_stride)
w_sub = model.scale_dimension(w, 1.0 / decoder_output_stride)
labels_output_size = tf.image.resize_nearest_neighbor(
last_frame_labels, [h_sub, w_sub], align_corners=True)
reference_labels_output_size = tf.image.resize_nearest_neighbor(
reference_labels, [h_sub, w_sub], align_corners=True)
if reduce_labels:
unique_labels, _ = tf.unique(tf.reshape(reference_labels_output_size, [-1]))
depth = tf.size(unique_labels)
else:
depth = tf.reduce_max(reference_labels_output_size) + 1
one_hot_assertion = tf.assert_less(tf.reduce_max(labels_output_size), depth)
with tf.control_dependencies([one_hot_assertion]):
init_softmax = tf.one_hot(tf.squeeze(labels_output_size,
axis=-1),
depth=depth,
dtype=tf.float32)
return init_softmax
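# Usage sketch (hypothetical helper, assuming a TF1 session): turn last
# frame labels into a one-hot map that can seed the softmax feedback loop.
def _example_create_initial_softmax_from_labels():
  last_frame_labels_np = np.zeros((1, 8, 8, 1), dtype=np.int32)
  last_frame_labels_np[0, 2:5, 2:5, 0] = 1
  reference_labels_np = np.zeros((1, 8, 8, 1), dtype=np.int32)
  reference_labels_np[0, 3:6, 3:6, 0] = 1
  init_softmax = create_initial_softmax_from_labels(
      tf.constant(last_frame_labels_np), tf.constant(reference_labels_np),
      decoder_output_stride=None, reduce_labels=True)
  with tf.Session() as sess:
    print(sess.run(init_softmax).shape)  # -> (1, 8, 8, 2)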
def local_previous_frame_nearest_neighbor_features_per_object(
prev_frame_embedding, query_embedding, prev_frame_labels,
gt_ids, max_distance=9):
"""Computes nearest neighbor features while only allowing local matches.
Args:
prev_frame_embedding: Tensor of shape [height, width, embedding_dim],
the embedding vectors for the last frame.
query_embedding: Tensor of shape [height, width, embedding_dim],
the embedding vectors for the query frames.
prev_frame_labels: Tensor of shape [height, width, 1], the class labels of
the previous frame.
gt_ids: Int Tensor of shape [n_objs] of the sorted unique ground truth
ids in the first frame.
max_distance: Integer, the maximum distance allowed for local matching.
Returns:
nn_features: A float32 tensor of nearest neighbor features of shape
[1, height, width, n_objects, 1].
"""
with tf.name_scope(
'local_previous_frame_nearest_neighbor_features_per_object'):
if USE_CORRELATION_COST:
tf.logging.info('Using correlation_cost.')
d = local_pairwise_distances(query_embedding, prev_frame_embedding,
max_distance=max_distance)
else:
# Slow fallback in case correlation_cost is not available.
tf.logging.warn('correlation cost is not available, using slow fallback '
'implementation.')
d = local_pairwise_distances2(query_embedding, prev_frame_embedding,
max_distance=max_distance)
d = (tf.nn.sigmoid(d) - 0.5) * 2
height = tf.shape(prev_frame_embedding)[0]
width = tf.shape(prev_frame_embedding)[1]
# Create offset versions of the mask.
if USE_CORRELATION_COST:
# New, faster code with cross-correlation via correlation_cost.
# Due to padding we have to add 1 to the labels.
offset_labels = correlation_cost_op.correlation_cost(
tf.ones((1, height, width, 1)),
tf.cast(prev_frame_labels + 1, tf.float32)[tf.newaxis],
kernel_size=1,
max_displacement=max_distance, stride_1=1, stride_2=1,
pad=max_distance)
offset_labels = tf.squeeze(offset_labels, axis=0)[..., tf.newaxis]
# Subtract the 1 again and round.
offset_labels = tf.round(offset_labels - 1)
offset_masks = tf.equal(
offset_labels,
tf.cast(gt_ids, tf.float32)[tf.newaxis, tf.newaxis, tf.newaxis, :])
else:
# Slower code, without a dependency on correlation_cost.
masks = tf.equal(prev_frame_labels, gt_ids[tf.newaxis, tf.newaxis])
padded_masks = tf.pad(masks,
[[max_distance, max_distance],
[max_distance, max_distance],
[0, 0]])
offset_masks = []
for y_start in range(2 * max_distance + 1):
y_end = y_start + height
masks_slice = padded_masks[y_start:y_end]
for x_start in range(2 * max_distance + 1):
x_end = x_start + width
offset_mask = masks_slice[:, x_start:x_end]
offset_masks.append(offset_mask)
offset_masks = tf.stack(offset_masks, axis=2)
pad = tf.ones((height, width, (2 * max_distance + 1) ** 2, tf.size(gt_ids)))
d_tiled = tf.tile(d[..., tf.newaxis], multiples=(1, 1, 1, tf.size(gt_ids)))
d_masked = tf.where(offset_masks, d_tiled, pad)
dists = tf.reduce_min(d_masked, axis=2)
dists = tf.reshape(dists, (1, height, width, tf.size(gt_ids), 1))
return dists
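# End-to-end sketch for the global matching path (hypothetical helper,
# assuming a TF1 session): per-object nearest-neighbor distance features of
# two query frames with respect to a labelled reference frame.
def _example_nearest_neighbor_features_per_object():
  reference_embeddings = tf.constant(
      np.random.rand(16, 16, 8), dtype=tf.float32)
  query_embeddings = tf.constant(
      np.random.rand(2, 16, 16, 8), dtype=tf.float32)
  reference_labels = tf.constant(
      np.random.randint(0, 3, size=(16, 16, 1)), dtype=tf.int32)
  nn_features, gt_ids = nearest_neighbor_features_per_object(
      reference_embeddings, query_embeddings, reference_labels,
      max_neighbors_per_object=0, k_nearest_neighbors=1, n_chunks=1)
  with tf.Session() as sess:
    features_val, gt_ids_val = sess.run([nn_features, gt_ids])
  # features_val has shape [2, 16, 16, n_objects, 1] and gt_ids_val holds
  # the sorted object ids found in the reference labels.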
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for embedding utils."""
import unittest
import numpy as np
import tensorflow as tf
from feelvos.utils import embedding_utils
if embedding_utils.USE_CORRELATION_COST:
# pylint: disable=g-import-not-at-top
from correlation_cost.python.ops import correlation_cost_op
class EmbeddingUtilsTest(tf.test.TestCase):
def test_pairwise_distances(self):
x = np.arange(100, dtype=np.float32).reshape(20, 5)
y = np.arange(100, 200, dtype=np.float32).reshape(20, 5)
g = tf.Graph()
with g.as_default():
with self.test_session(graph=g) as sess:
x = tf.constant(x)
y = tf.constant(y)
d1 = embedding_utils.pairwise_distances(x, y)
d2 = embedding_utils.pairwise_distances2(x, y)
d1_val, d2_val = sess.run([d1, d2])
self.assertAllClose(d1_val, d2_val)
@unittest.skipIf(not embedding_utils.USE_CORRELATION_COST,
'depends on correlation_cost')
def test_correlation_cost_one_dimensional(self):
a = np.array([[[[1.0], [2.0]], [[3.0], [4.0]]]])
b = np.array([[[[2.0], [1.0]], [[4.0], [3.0]]]])
g = tf.Graph()
with g.as_default():
with self.test_session(graph=g) as sess:
c = correlation_cost_op.correlation_cost(
a, b, kernel_size=1, max_displacement=1, stride_1=1, stride_2=1,
pad=1)
c = tf.squeeze(c, axis=0)
c_val = sess.run(c)
self.assertAllEqual(c_val.shape, (2, 2, 9))
for y in range(2):
for x in range(2):
for dy in range(-1, 2):
for dx in range(-1, 2):
a_slice = a[0, y, x, 0]
if y + dy < 0 or y + dy > 1 or x + dx < 0 or x + dx > 1:
b_slice = 0
else:
b_slice = b[0, y + dy, x + dx, 0]
expected = a_slice * b_slice
dy0 = dy + 1
dx0 = dx + 1
self.assertAlmostEqual(c_val[y, x, 3 * dy0 + dx0], expected)
@unittest.skipIf(not embedding_utils.USE_CORRELATION_COST,
'depends on correlation_cost')
def test_correlation_cost_two_dimensional(self):
a = np.array([[[[1.0, -5.0], [7.0, 2.0]], [[1.0, 3.0], [3.0, 4.0]]]])
b = np.array([[[[2.0, 1.0], [0.0, -9.0]], [[4.0, 3.0], [3.0, 1.0]]]])
g = tf.Graph()
with g.as_default():
with self.test_session(graph=g) as sess:
c = correlation_cost_op.correlation_cost(
a, b, kernel_size=1, max_displacement=1, stride_1=1, stride_2=1,
pad=1)
c = tf.squeeze(c, axis=0)
c_val = sess.run(c)
self.assertAllEqual(c_val.shape, (2, 2, 9))
for y in range(2):
for x in range(2):
for dy in range(-1, 2):
for dx in range(-1, 2):
a_slice = a[0, y, x, :]
if y + dy < 0 or y + dy > 1 or x + dx < 0 or x + dx > 1:
b_slice = 0
else:
b_slice = b[0, y + dy, x + dx, :]
expected = (a_slice * b_slice).mean()
dy0 = dy + 1
dx0 = dx + 1
self.assertAlmostEqual(c_val[y, x, 3 * dy0 + dx0], expected)
@unittest.skipIf(not embedding_utils.USE_CORRELATION_COST,
'depends on correlation_cost')
def test_local_pairwise_distances_one_dimensional(self):
a = np.array([[[1.0], [2.0]], [[3.0], [4.0]]])
b = np.array([[[2.0], [1.0]], [[4.0], [3.0]]])
g = tf.Graph()
with g.as_default():
with self.test_session(graph=g) as sess:
a_tf = tf.constant(a, dtype=tf.float32)
b_tf = tf.constant(b, dtype=tf.float32)
d = embedding_utils.local_pairwise_distances(a_tf, b_tf,
max_distance=1)
d_val = sess.run(d)
for y in range(2):
for x in range(2):
for dy in range(-1, 2):
for dx in range(-1, 2):
a_slice = a[y, x, 0]
if y + dy < 0 or y + dy > 1 or x + dx < 0 or x + dx > 1:
expected = np.float('inf')
else:
b_slice = b[y + dy, x + dx, 0]
expected = (a_slice - b_slice) ** 2
dy0 = dy + 1
dx0 = dx + 1
self.assertAlmostEqual(d_val[y, x, 3 * dy0 + dx0], expected)
@unittest.skipIf(not embedding_utils.USE_CORRELATION_COST,
'depends on correlation_cost')
def test_local_pairwise_distances_shape(self):
a = np.zeros((4, 5, 2))
b = np.zeros((4, 5, 2))
g = tf.Graph()
with g.as_default():
with self.test_session(graph=g) as sess:
a_tf = tf.constant(a, dtype=tf.float32)
b_tf = tf.constant(b, dtype=tf.float32)
d = embedding_utils.local_pairwise_distances(a_tf, b_tf, max_distance=4)
d_val = sess.run(d)
self.assertAllEqual(d_val.shape, (4, 5, 81))
@unittest.skipIf(not embedding_utils.USE_CORRELATION_COST,
'depends on correlation_cost')
def test_local_pairwise_distances_two_dimensional(self):
a = np.array([[[1.0, -5.0], [7.0, 2.0]], [[1.0, 3.0], [3.0, 4.0]]])
b = np.array([[[2.0, 1.0], [0.0, -9.0]], [[4.0, 3.0], [3.0, 1.0]]])
g = tf.Graph()
with g.as_default():
with self.test_session(graph=g) as sess:
a_tf = tf.constant(a, dtype=tf.float32)
b_tf = tf.constant(b, dtype=tf.float32)
d = embedding_utils.local_pairwise_distances(a_tf, b_tf,
max_distance=1)
d_val = sess.run(d)
for y in range(2):
for x in range(2):
for dy in range(-1, 2):
for dx in range(-1, 2):
a_slice = a[y, x, :]
if y + dy < 0 or y + dy > 1 or x + dx < 0 or x + dx > 1:
expected = np.float('inf')
else:
b_slice = b[y + dy, x + dx, :]
expected = ((a_slice - b_slice) ** 2).sum()
dy0 = dy + 1
dx0 = dx + 1
self.assertAlmostEqual(d_val[y, x, 3 * dy0 + dx0], expected)
@unittest.skipIf(not embedding_utils.USE_CORRELATION_COST,
'depends on correlation_cost')
def test_local_previous_frame_nearest_neighbor_features_per_object(self):
prev_frame_embedding = np.array([[[1.0, -5.0], [7.0, 2.0]],
[[1.0, 3.0], [3.0, 4.0]]]) / 10
query_embedding = np.array([[[2.0, 1.0], [0.0, -9.0]],
[[4.0, 3.0], [3.0, 1.0]]]) / 10
prev_frame_labels = np.array([[[0], [1]], [[1], [0]]])
gt_ids = np.array([0, 1])
g = tf.Graph()
with g.as_default():
with self.test_session(graph=g) as sess:
prev_frame_embedding_tf = tf.constant(prev_frame_embedding,
dtype=tf.float32)
query_embedding_tf = tf.constant(query_embedding, dtype=tf.float32)
embu = embedding_utils
dists = (
embu.local_previous_frame_nearest_neighbor_features_per_object(
prev_frame_embedding_tf, query_embedding_tf,
prev_frame_labels, gt_ids, max_distance=1))
dists = tf.squeeze(dists, axis=4)
dists = tf.squeeze(dists, axis=0)
dists_val = sess.run(dists)
for obj_id in gt_ids:
for y in range(2):
for x in range(2):
curr_min = 1.0
for dy in range(-1, 2):
for dx in range(-1, 2):
# Attention: here we shift the prev frame embedding,
# not the query.
if y + dy < 0 or y + dy > 1 or x + dx < 0 or x + dx > 1:
continue
if prev_frame_labels[y + dy, x + dx, 0] != obj_id:
continue
prev_frame_slice = prev_frame_embedding[y + dy, x + dx, :]
query_frame_slice = query_embedding[y, x, :]
v_unnorm = ((prev_frame_slice - query_frame_slice) ** 2).sum()
v = ((1.0 / (1.0 + np.exp(-v_unnorm))) - 0.5) * 2
curr_min = min(curr_min, v)
expected = curr_min
self.assertAlmostEqual(dists_val[y, x, obj_id], expected,
places=5)
if __name__ == '__main__':
tf.test.main()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for evaluations."""
import numpy as np
import PIL
import tensorflow as tf
pascal_colormap = [
0, 0, 0,
0.5020, 0, 0,
0, 0.5020, 0,
0.5020, 0.5020, 0,
0, 0, 0.5020,
0.5020, 0, 0.5020,
0, 0.5020, 0.5020,
0.5020, 0.5020, 0.5020,
0.2510, 0, 0,
0.7529, 0, 0,
0.2510, 0.5020, 0,
0.7529, 0.5020, 0,
0.2510, 0, 0.5020,
0.7529, 0, 0.5020,
0.2510, 0.5020, 0.5020,
0.7529, 0.5020, 0.5020,
0, 0.2510, 0,
0.5020, 0.2510, 0,
0, 0.7529, 0,
0.5020, 0.7529, 0,
0, 0.2510, 0.5020,
0.5020, 0.2510, 0.5020,
0, 0.7529, 0.5020,
0.5020, 0.7529, 0.5020,
0.2510, 0.2510, 0]
def save_segmentation_with_colormap(filename, img):
"""Saves a segmentation with the pascal colormap as expected for DAVIS eval.
Args:
filename: Where to store the segmentation.
img: A numpy array of the segmentation to be saved.
"""
if img.shape[-1] == 1:
img = img[..., 0]
# Save with colormap.
colormap = (np.array(pascal_colormap) * 255).round().astype('uint8')
colormap_image = PIL.Image.new('P', (16, 16))
colormap_image.putpalette(colormap)
pil_image = PIL.Image.fromarray(img.astype('uint8'))
pil_image_with_colormap = pil_image.quantize(palette=colormap_image)
with tf.gfile.GFile(filename, 'w') as f:
pil_image_with_colormap.save(f)
def save_embeddings(filename, embeddings):
with tf.gfile.GFile(filename, 'w') as f:
np.save(f, embeddings)
def calculate_iou(pred_labels, ref_labels):
"""Calculates the intersection over union for binary segmentation.
Args:
pred_labels: predicted segmentation labels.
ref_labels: reference segmentation labels.
Returns:
The IoU between pred_labels and ref_labels
"""
if ref_labels.any():
i = np.logical_and(pred_labels, ref_labels).sum()
u = np.logical_or(pred_labels, ref_labels).sum()
return i.astype('float') / u
else:
if pred_labels.any():
return 0.0
else:
return 1.0
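# Tiny numeric example (hypothetical helper): the prediction covers the
# whole reference plus two extra pixels, so IoU = intersection 2 / union 4.
def _example_calculate_iou():
  pred_labels = np.zeros((4, 4), dtype=bool)
  ref_labels = np.zeros((4, 4), dtype=bool)
  pred_labels[1, 0:4] = True  # 4 predicted pixels
  ref_labels[1, 2:4] = True   # 2 reference pixels, both predicted
  assert calculate_iou(pred_labels, ref_labels) == 0.5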
def calculate_multi_object_miou_tf(pred_labels, ref_labels):
"""Calculates the mIoU for a batch of predicted and reference labels.
Args:
pred_labels: Int32 tensor of shape [batch, height, width, 1].
ref_labels: Int32 tensor of shape [batch, height, width, 1].
Returns:
The mIoU between pred_labels and ref_labels as float32 scalar tensor.
"""
def calculate_multi_object_miou(pred_labels_, ref_labels_):
"""Calculates the mIoU for predicted and reference labels in numpy.
Args:
pred_labels_: int32 np.array of shape [batch, height, width, 1].
ref_labels_: int32 np.array of shape [batch, height, width, 1].
Returns:
The mIoU between pred_labels_ and ref_labels_.
"""
assert len(pred_labels_.shape) == 4
assert pred_labels_.shape[3] == 1
assert pred_labels_.shape == ref_labels_.shape
ious = []
for pred_label, ref_label in zip(pred_labels_, ref_labels_):
ids = np.setdiff1d(np.unique(ref_label), [0])
if ids.size == 0:
continue
for id_ in ids:
iou = calculate_iou(pred_label == id_, ref_label == id_)
ious.append(iou)
if ious:
return np.cast['float32'](np.mean(ious))
else:
return np.cast['float32'](1.0)
miou = tf.py_func(calculate_multi_object_miou, [pred_labels, ref_labels],
tf.float32, name='calculate_multi_object_miou')
miou.set_shape(())
return miou
def calculate_multi_object_ious(pred_labels, ref_labels, label_set):
"""Calculates the intersection over union for binary segmentation.
Args:
pred_labels: predicted segmentation labels.
ref_labels: reference segmentation labels.
label_set: int np.array of object ids.
Returns:
float np.array of IoUs between pred_labels and ref_labels
for each object in label_set.
"""
# Background should not be included as object label.
return np.array([calculate_iou(pred_labels == label, ref_labels == label)
for label in label_set if label != 0])
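# Usage sketch (hypothetical helper): per-object IoUs for a frame with two
# objects; the background label 0 is skipped by calculate_multi_object_ious.
def _example_calculate_multi_object_ious():
  pred_labels = np.array([[0, 1, 1],
                          [0, 2, 2],
                          [0, 0, 2]])
  ref_labels = np.array([[0, 1, 1],
                         [0, 2, 0],
                         [0, 0, 2]])
  ious = calculate_multi_object_ious(
      pred_labels, ref_labels, label_set=np.array([0, 1, 2]))
  # -> approximately [1.0, 0.667]: object 1 matches exactly, object 2
  #    overlaps in 2 of 3 pixels.
  return ious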
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for artificially damaging segmentation masks."""
import numpy as np
from scipy.ndimage import interpolation
from skimage import morphology
from skimage import transform
import tensorflow as tf
def damage_masks(labels, shift=True, scale=True, rotate=True, dilate=True):
"""Damages segmentation masks by random transformations.
Args:
labels: Int32 labels tensor of shape (height, width, 1).
shift: Boolean, whether to damage the masks by shifting.
scale: Boolean, whether to damage the masks by scaling.
rotate: Boolean, whether to damage the masks by rotation.
dilate: Boolean, whether to damage the masks by dilation.
Returns:
The damaged version of labels.
"""
def _damage_masks_np(labels_):
return damage_masks_np(labels_, shift, scale, rotate, dilate)
damaged_masks = tf.py_func(_damage_masks_np, [labels], tf.int32,
name='damage_masks')
damaged_masks.set_shape(labels.get_shape())
return damaged_masks
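# Usage sketch (hypothetical helper, assuming a TF1 session): randomly
# perturb a single-object mask, mirroring how the previous-frame mask is
# damaged during training (dilate=False as in the training code).
def _example_damage_masks():
  labels_np = np.zeros((32, 32, 1), dtype=np.int32)
  labels_np[8:24, 8:24, 0] = 1
  damaged = damage_masks(tf.constant(labels_np), dilate=False)
  with tf.Session() as sess:
    damaged_val = sess.run(damaged)  # same shape, randomly transformed mask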
def damage_masks_np(labels, shift=True, scale=True, rotate=True, dilate=True):
"""Performs the actual mask damaging in numpy.
Args:
labels: Int32 numpy array of shape (height, width, 1).
shift: Boolean, whether to damage the masks by shifting.
scale: Boolean, whether to damage the masks by scaling.
rotate: Boolean, whether to damage the masks by rotation.
dilate: Boolean, whether to damage the masks by dilation.
Returns:
The damaged version of labels.
"""
unique_labels = np.unique(labels)
unique_labels = np.setdiff1d(unique_labels, [0])
# Shuffle to get random depth ordering when combining together.
np.random.shuffle(unique_labels)
damaged_labels = np.zeros_like(labels)
for l in unique_labels:
obj_mask = (labels == l)
damaged_obj_mask = _damage_single_object_mask(obj_mask, shift, scale,
rotate, dilate)
damaged_labels[damaged_obj_mask] = l
return damaged_labels
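# Illustrative usage sketch (not part of the original code): damaging a toy
# two-object label map. damage_masks_np expects an integer array of shape
# (height, width, 1); the sizes and object ids below are made up.
def _example_damage_masks_np():
  toy_labels = np.zeros((64, 64, 1), dtype=np.int32)
  toy_labels[10:30, 10:30, 0] = 1  # first object
  toy_labels[35:55, 35:55, 0] = 2  # second object
  # Returns a label map of the same shape with randomly perturbed masks.
  return damage_masks_np(toy_labels, shift=True, scale=True, rotate=True,
                         dilate=True)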
def _damage_single_object_mask(mask, shift, scale, rotate, dilate):
"""Performs mask damaging in numpy for a single object.
Args:
    mask: Boolean numpy array of shape (height, width, 1).
shift: Boolean, whether to damage the masks by shifting.
scale: Boolean, whether to damage the masks by scaling.
rotate: Boolean, whether to damage the masks by rotation.
dilate: Boolean, whether to damage the masks by dilation.
Returns:
The damaged version of mask.
"""
  # For now we just do shifting and scaling. Better would be affine or thin
  # plate spline transformations.
if shift:
mask = _shift_mask(mask)
if scale:
mask = _scale_mask(mask)
if rotate:
mask = _rotate_mask(mask)
if dilate:
mask = _dilate_mask(mask)
return mask
def _shift_mask(mask, max_shift_factor=0.05):
"""Damages a mask for a single object by randomly shifting it in numpy.
Args:
    mask: Boolean numpy array of shape (height, width, 1).
    max_shift_factor: Float scalar, the maximum shift, as a fraction of the
      square root of the object's bounding box area.
Returns:
The shifted version of mask.
"""
nzy, nzx, _ = mask.nonzero()
h = nzy.max() - nzy.min()
w = nzx.max() - nzx.min()
size = np.sqrt(h * w)
offset = np.random.uniform(-size * max_shift_factor, size * max_shift_factor,
2)
shifted_mask = interpolation.shift(np.squeeze(mask, axis=2),
offset, order=0).astype('bool')[...,
np.newaxis]
return shifted_mask
def _scale_mask(mask, scale_amount=0.025):
"""Damages a mask for a single object by randomly scaling it in numpy.
Args:
    mask: Boolean numpy array of shape (height, width, 1).
    scale_amount: Float scalar, the maximum deviation of the random scale
      factor from 1.0.
Returns:
The scaled version of mask.
"""
nzy, nzx, _ = mask.nonzero()
cy = 0.5 * (nzy.max() - nzy.min())
cx = 0.5 * (nzx.max() - nzx.min())
scale_factor = np.random.uniform(1.0 - scale_amount, 1.0 + scale_amount)
shift = transform.SimilarityTransform(translation=[-cx, -cy])
inv_shift = transform.SimilarityTransform(translation=[cx, cy])
s = transform.SimilarityTransform(scale=[scale_factor, scale_factor])
m = (shift + (s + inv_shift)).inverse
scaled_mask = transform.warp(mask, m) > 0.5
return scaled_mask
def _rotate_mask(mask, max_rot_degrees=3.0):
"""Damages a mask for a single object by randomly rotating it in numpy.
Args:
    mask: Boolean numpy array of shape (height, width, 1).
max_rot_degrees: Float scalar, the maximum number of degrees to rotate.
Returns:
    The rotated version of mask.
"""
cy = 0.5 * mask.shape[0]
cx = 0.5 * mask.shape[1]
rot_degrees = np.random.uniform(-max_rot_degrees, max_rot_degrees)
shift = transform.SimilarityTransform(translation=[-cx, -cy])
inv_shift = transform.SimilarityTransform(translation=[cx, cy])
r = transform.SimilarityTransform(rotation=np.deg2rad(rot_degrees))
m = (shift + (r + inv_shift)).inverse
scaled_mask = transform.warp(mask, m) > 0.5
return scaled_mask
def _dilate_mask(mask, dilation_radius=5):
"""Damages a mask for a single object by dilating it in numpy.
Args:
    mask: Boolean numpy array of shape (height, width, 1).
    dilation_radius: Integer, the radius of the disk structuring element used
      for the dilation.
Returns:
The dilated version of mask.
"""
disk = morphology.disk(dilation_radius, dtype=np.bool)
dilated_mask = morphology.binary_dilation(
np.squeeze(mask, axis=2), selem=disk)[..., np.newaxis]
return dilated_mask
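# Illustrative sketch (not part of the original code): the effect of
# _dilate_mask on a toy single-pixel mask. With a radius of 5 the single
# foreground pixel grows into the disk-shaped structuring element.
def _example_dilate_mask():
  toy_mask = np.zeros((21, 21, 1), dtype=np.bool_)
  toy_mask[10, 10, 0] = True
  dilated = _dilate_mask(toy_mask, dilation_radius=5)
  return dilated.sum()  # number of foreground pixels after dilation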
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for training."""
import collections
import six
import tensorflow as tf
from deeplab.core import preprocess_utils
from deeplab.utils import train_utils
from feelvos.utils import embedding_utils
from feelvos.utils import eval_utils
slim = tf.contrib.slim
add_softmax_cross_entropy_loss_for_each_scale = (
train_utils.add_softmax_cross_entropy_loss_for_each_scale)
get_model_gradient_multipliers = train_utils.get_model_gradient_multipliers
get_model_learning_rate = train_utils.get_model_learning_rate
resolve_shape = preprocess_utils.resolve_shape
def add_triplet_loss_for_each_scale(batch_size, num_frames_per_video,
embedding_dim, scales_to_embeddings,
labels, scope):
"""Adds triplet loss for logits of each scale.
Args:
batch_size: Int, the number of video chunks sampled per batch
num_frames_per_video: Int, the number of frames per video.
embedding_dim: Int, the dimension of the learned embedding
scales_to_embeddings: A map from embedding names for different scales to
embeddings. The embeddings have shape [batch, embeddings_height,
embeddings_width, embedding_dim].
labels: Groundtruth labels with shape [batch, image_height, image_width, 1].
scope: String, the scope for the loss.
Raises:
ValueError: labels is None.
"""
if labels is None:
raise ValueError('No label for triplet loss.')
  for scale, embeddings in six.iteritems(scales_to_embeddings):
loss_scope = None
if scope:
loss_scope = '%s_%s' % (scope, scale)
    # Label is downsampled to the same size as the embeddings.
scaled_labels = tf.image.resize_nearest_neighbor(
labels,
resolve_shape(embeddings, 4)[1:3],
align_corners=True)
# Reshape from [batch * num_frames, ...] to [batch, num_frames, ...].
h = tf.shape(embeddings)[1]
w = tf.shape(embeddings)[2]
new_labels_shape = tf.stack([batch_size, num_frames_per_video, h, w, 1])
reshaped_labels = tf.reshape(scaled_labels, new_labels_shape)
new_embeddings_shape = tf.stack([batch_size, num_frames_per_video, h, w,
-1])
reshaped_embeddings = tf.reshape(embeddings, new_embeddings_shape)
with tf.name_scope(loss_scope):
total_loss = tf.constant(0, dtype=tf.float32)
for n in range(batch_size):
embedding = reshaped_embeddings[n]
label = reshaped_labels[n]
n_pixels = h * w
n_anchors_used = 256
sampled_anchor_indices = tf.random_shuffle(tf.range(n_pixels))[
:n_anchors_used]
anchors_pool = tf.reshape(embedding[0], [-1, embedding_dim])
anchors_pool_classes = tf.reshape(label[0], [-1])
anchors = tf.gather(anchors_pool, sampled_anchor_indices)
anchor_classes = tf.gather(anchors_pool_classes, sampled_anchor_indices)
pos_neg_pool = tf.reshape(embedding[1:], [-1, embedding_dim])
pos_neg_pool_classes = tf.reshape(label[1:], [-1])
dists = embedding_utils.pairwise_distances(anchors, pos_neg_pool)
pos_mask = tf.equal(anchor_classes[:, tf.newaxis],
pos_neg_pool_classes[tf.newaxis, :])
neg_mask = tf.logical_not(pos_mask)
pos_mask_f = tf.cast(pos_mask, tf.float32)
neg_mask_f = tf.cast(neg_mask, tf.float32)
pos_dists = pos_mask_f * dists + 1e20 * neg_mask_f
neg_dists = neg_mask_f * dists + 1e20 * pos_mask_f
pos_dists_min = tf.reduce_min(pos_dists, axis=1)
neg_dists_min = tf.reduce_min(neg_dists, axis=1)
margin = 1.0
loss = tf.nn.relu(pos_dists_min - neg_dists_min + margin)
# Handle case that no positive is present (per anchor).
any_pos = tf.reduce_any(pos_mask, axis=1)
loss *= tf.cast(any_pos, tf.float32)
# Average over anchors
loss = tf.reduce_mean(loss, axis=0)
total_loss += loss
total_loss /= batch_size
# Scale the loss up a bit.
total_loss *= 3.0
tf.add_to_collection(tf.GraphKeys.LOSSES, total_loss)
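# Illustrative numpy sketch (not part of the original code) of the per-anchor
# hinge used above: for every anchor pixel, the closest positive and closest
# negative in the pool are found and the loss is relu(d_pos - d_neg + margin).
# All values below are made up, and plain Euclidean distances are used as a
# stand-in for embedding_utils.pairwise_distances.
def _example_triplet_hinge(margin=1.0):
  import numpy as np
  anchors = np.array([[0.0, 0.0], [1.0, 1.0]], dtype=np.float32)
  anchor_classes = np.array([1, 2])
  pool = np.array([[0.1, 0.0], [2.0, 2.0], [1.1, 1.0]], dtype=np.float32)
  pool_classes = np.array([1, 1, 2])
  dists = np.linalg.norm(anchors[:, np.newaxis] - pool[np.newaxis], axis=-1)
  pos_mask = anchor_classes[:, np.newaxis] == pool_classes[np.newaxis, :]
  pos_min = np.where(pos_mask, dists, np.inf).min(axis=1)
  neg_min = np.where(~pos_mask, dists, np.inf).min(axis=1)
  loss = np.maximum(pos_min - neg_min + margin, 0.0)
  loss *= pos_mask.any(axis=1)  # ignore anchors without any positive
  return loss.mean()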
def add_dynamic_softmax_cross_entropy_loss_for_each_scale(
scales_to_logits, labels, ignore_label, loss_weight=1.0,
upsample_logits=True, scope=None, top_k_percent_pixels=1.0,
hard_example_mining_step=100000):
"""Adds softmax cross entropy loss per scale for logits with varying classes.
Also adds summaries for mIoU.
Args:
scales_to_logits: A map from logits names for different scales to logits.
The logits are a list of length batch_size of tensors of shape
[time, logits_height, logits_width, num_classes].
labels: Groundtruth labels with shape [batch_size * time, image_height,
image_width, 1].
ignore_label: Integer, label to ignore.
loss_weight: Float, loss weight.
upsample_logits: Boolean, upsample logits or not.
scope: String, the scope for the loss.
top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its
value < 1.0, only compute the loss for the top k percent pixels (e.g.,
the top 20% pixels). This is useful for hard pixel mining.
hard_example_mining_step: An integer, the training step in which the
      hard example mining kicks off. Note that we gradually reduce the
mining percent to the top_k_percent_pixels. For example, if
hard_example_mining_step=100K and top_k_percent_pixels=0.25, then
mining percent will gradually reduce from 100% to 25% until 100K steps
after which we only mine top 25% pixels.
Raises:
ValueError: Label or logits is None.
"""
if labels is None:
raise ValueError('No label for softmax cross entropy loss.')
if top_k_percent_pixels < 0 or top_k_percent_pixels > 1:
raise ValueError('Unexpected value of top_k_percent_pixels.')
for scale, logits in six.iteritems(scales_to_logits):
loss_scope = None
if scope:
loss_scope = '%s_%s' % (scope, scale)
if upsample_logits:
# Label is not downsampled, and instead we upsample logits.
assert isinstance(logits, collections.Sequence)
logits = [tf.image.resize_bilinear(
x,
preprocess_utils.resolve_shape(labels, 4)[1:3],
align_corners=True) for x in logits]
scaled_labels = labels
else:
# Label is downsampled to the same size as logits.
assert isinstance(logits, collections.Sequence)
scaled_labels = tf.image.resize_nearest_neighbor(
labels,
preprocess_utils.resolve_shape(logits[0], 4)[1:3],
align_corners=True)
batch_size = len(logits)
num_time = preprocess_utils.resolve_shape(logits[0])[0]
reshaped_labels = tf.reshape(
scaled_labels, ([batch_size, num_time] +
preprocess_utils.resolve_shape(scaled_labels)[1:]))
for n, logits_n in enumerate(logits):
labels_n = reshaped_labels[n]
labels_n = tf.reshape(labels_n, shape=[-1])
not_ignore_mask = tf.to_float(tf.not_equal(labels_n,
ignore_label)) * loss_weight
num_classes_n = tf.shape(logits_n)[-1]
one_hot_labels = slim.one_hot_encoding(
labels_n, num_classes_n, on_value=1.0, off_value=0.0)
logits_n_flat = tf.reshape(logits_n, shape=[-1, num_classes_n])
if top_k_percent_pixels == 1.0:
tf.losses.softmax_cross_entropy(
one_hot_labels,
logits_n_flat,
weights=not_ignore_mask,
scope=loss_scope)
else:
# Only compute the loss for top k percent pixels.
# First, compute the loss for all pixels. Note we do not put the loss
# to loss_collection and set reduction = None to keep the shape.
num_pixels = tf.to_float(tf.shape(logits_n_flat)[0])
pixel_losses = tf.losses.softmax_cross_entropy(
one_hot_labels,
logits_n_flat,
weights=not_ignore_mask,
scope='pixel_losses',
loss_collection=None,
reduction=tf.losses.Reduction.NONE)
# Compute the top_k_percent pixels based on current training step.
if hard_example_mining_step == 0:
# Directly focus on the top_k pixels.
top_k_pixels = tf.to_int32(top_k_percent_pixels * num_pixels)
else:
# Gradually reduce the mining percent to top_k_percent_pixels.
global_step = tf.to_float(tf.train.get_or_create_global_step())
ratio = tf.minimum(1.0, global_step / hard_example_mining_step)
top_k_pixels = tf.to_int32(
(ratio * top_k_percent_pixels + (1.0 - ratio)) * num_pixels)
_, top_k_indices = tf.nn.top_k(pixel_losses,
k=top_k_pixels,
sorted=True,
name='top_k_percent_pixels')
# Compute the loss for the top k percent pixels.
tf.losses.softmax_cross_entropy(
tf.gather(one_hot_labels, top_k_indices),
tf.gather(logits_n_flat, top_k_indices),
weights=tf.gather(not_ignore_mask, top_k_indices),
scope=loss_scope)
pred_n = tf.argmax(logits_n, axis=-1, output_type=tf.int32)[
..., tf.newaxis]
labels_n = labels[n * num_time: (n + 1) * num_time]
miou = eval_utils.calculate_multi_object_miou_tf(pred_n, labels_n)
tf.summary.scalar('miou', miou)
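# Illustrative sketch (not part of the original code) of the hard example
# mining schedule used above: the fraction of pixels whose loss is kept decays
# linearly from 100% to top_k_percent_pixels over hard_example_mining_step
# steps, e.g. _example_mining_fraction(0) == 1.0 and
# _example_mining_fraction(100000) == 0.25 with the defaults below.
def _example_mining_fraction(global_step, top_k_percent_pixels=0.25,
                             hard_example_mining_step=100000):
  ratio = min(1.0, float(global_step) / hard_example_mining_step)
  return ratio * top_k_percent_pixels + (1.0 - ratio)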
def get_model_init_fn(train_logdir,
tf_initial_checkpoint,
initialize_last_layer,
last_layers,
ignore_missing_vars=False):
"""Gets the function initializing model variables from a checkpoint.
Args:
train_logdir: Log directory for training.
tf_initial_checkpoint: TensorFlow checkpoint for initialization.
initialize_last_layer: Initialize last layer or not.
last_layers: Last layers of the model.
ignore_missing_vars: Ignore missing variables in the checkpoint.
Returns:
Initialization function.
"""
if tf_initial_checkpoint is None:
tf.logging.info('Not initializing the model from a checkpoint.')
return None
if tf.train.latest_checkpoint(train_logdir):
tf.logging.info('Ignoring initialization; other checkpoint exists')
return None
tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)
# Variables that will not be restored.
exclude_list = ['global_step']
if not initialize_last_layer:
exclude_list.extend(last_layers)
variables_to_restore = slim.get_variables_to_restore(exclude=exclude_list)
if variables_to_restore:
return slim.assign_from_checkpoint_fn(
tf_initial_checkpoint,
variables_to_restore,
ignore_missing_vars=ignore_missing_vars)
return None
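# Illustrative usage sketch (not part of the original code): the returned
# function is typically passed to the slim training loop; the flag and
# variable names below are placeholders.
#
#   init_fn = get_model_init_fn(FLAGS.train_logdir,
#                               FLAGS.tf_initial_checkpoint,
#                               FLAGS.initialize_last_layer,
#                               last_layers,
#                               ignore_missing_vars=True)
#   slim.learning.train(train_op, logdir=FLAGS.train_logdir, init_fn=init_fn)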
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for providing semantic segmentation video data."""
import tensorflow as tf
from feelvos import input_preprocess
from feelvos import model
from feelvos.utils import mask_damaging
from feelvos.utils import train_utils
slim = tf.contrib.slim
dataset_data_provider = slim.dataset_data_provider
MIN_LABEL_COUNT = 10
def decode_image_sequence(tensor, image_format='jpeg', shape=None,
channels=3, raw_dtype=tf.uint8):
"""Decodes a sequence of images.
Args:
tensor: the tensor of strings to decode, shape: [num_images]
image_format: a string (possibly tensor) with the format of the image.
Options include 'jpeg', 'png', and 'raw'.
shape: a list or tensor of the decoded image shape for a single image.
channels: if 'shape' is None, the third dimension of the image is set to
this value.
raw_dtype: if the image is encoded as raw bytes, this is the method of
decoding the bytes into values.
Returns:
The decoded images with shape [time, height, width, channels].
"""
handler = slim.tfexample_decoder.Image(
shape=shape, channels=channels, dtype=raw_dtype, repeated=True)
return handler.tensors_to_item({'image/encoded': tensor,
'image/format': image_format})
def _get_data(data_provider, dataset_split, video_frames_are_decoded):
"""Gets data from data provider.
Args:
data_provider: An object of slim.data_provider.
dataset_split: Dataset split.
video_frames_are_decoded: Boolean, whether the video frames are already
decoded
Returns:
image: Image Tensor.
label: Label Tensor storing segmentation annotations.
    object_label: An integer referring to the object label according to the
      labelmap. If the example has more than one object_label, take the first
      one.
image_name: Image name.
height: Image height.
width: Image width.
video_id: String tensor representing the name of the video.
Raises:
ValueError: Failed to find label.
"""
if video_frames_are_decoded:
image, = data_provider.get(['image'])
else:
image, = data_provider.get(['image/encoded'])
# Some datasets do not contain image_name.
if 'image_name' in data_provider.list_items():
image_name, = data_provider.get(['image_name'])
else:
image_name = tf.constant('')
height, width = data_provider.get(['height', 'width'])
label = None
if dataset_split != 'test':
if video_frames_are_decoded:
if 'labels_class' not in data_provider.list_items():
raise ValueError('Failed to find labels.')
label, = data_provider.get(['labels_class'])
else:
key = 'segmentation/object/encoded'
if key not in data_provider.list_items():
raise ValueError('Failed to find labels.')
label, = data_provider.get([key])
object_label = None
video_id, = data_provider.get(['video_id'])
return image, label, object_label, image_name, height, width, video_id
def _has_foreground_and_background_in_first_frame(label, subsampling_factor):
"""Checks if the labels have foreground and background in the first frame.
Args:
label: Label tensor of shape [num_frames, height, width, 1].
subsampling_factor: Integer, the subsampling factor.
Returns:
Boolean, whether the labels have foreground and background in the first
frame.
"""
h, w = train_utils.resolve_shape(label)[1:3]
label_downscaled = tf.squeeze(
tf.image.resize_nearest_neighbor(label[0, tf.newaxis],
[h // subsampling_factor,
w // subsampling_factor],
align_corners=True),
axis=0)
is_bg = tf.equal(label_downscaled, 0)
is_fg = tf.logical_not(is_bg)
  # Just using reduce_any was not robust enough, so let's make sure the count
  # is above MIN_LABEL_COUNT.
  fg_count = tf.reduce_sum(tf.cast(is_fg, tf.int32))
  bg_count = tf.reduce_sum(tf.cast(is_bg, tf.int32))
  has_fg = tf.greater_equal(fg_count, MIN_LABEL_COUNT)
  has_bg = tf.greater_equal(bg_count, MIN_LABEL_COUNT)
  return tf.logical_and(has_bg, has_fg)
def _has_foreground_and_background_in_first_frame_2(label,
decoder_output_stride):
"""Checks if the labels have foreground and background in the first frame.
Second attempt, this time we use the actual output dimension for resizing.
Args:
label: Label tensor of shape [num_frames, height, width, 1].
decoder_output_stride: Integer, the stride of the decoder output.
Returns:
Boolean, whether the labels have foreground and background in the first
frame.
"""
h, w = train_utils.resolve_shape(label)[1:3]
h_sub = model.scale_dimension(h, 1.0 / decoder_output_stride)
w_sub = model.scale_dimension(w, 1.0 / decoder_output_stride)
label_downscaled = tf.squeeze(
tf.image.resize_nearest_neighbor(label[0, tf.newaxis], [h_sub, w_sub],
align_corners=True), axis=0)
is_bg = tf.equal(label_downscaled, 0)
is_fg = tf.logical_not(is_bg)
  # Just using reduce_any was not robust enough, so let's make sure the count
  # is above MIN_LABEL_COUNT.
  fg_count = tf.reduce_sum(tf.cast(is_fg, tf.int32))
  bg_count = tf.reduce_sum(tf.cast(is_bg, tf.int32))
  has_fg = tf.greater_equal(fg_count, MIN_LABEL_COUNT)
  has_bg = tf.greater_equal(bg_count, MIN_LABEL_COUNT)
  return tf.logical_and(has_bg, has_fg)
def _has_enough_pixels_of_each_object_in_first_frame(
label, decoder_output_stride):
"""Checks if for each object (incl. background) enough pixels are visible.
During test time, we will usually not see a reference frame in which only
very few pixels of one object are visible. These cases can be problematic
during training, especially if more than the 1-nearest neighbor is used.
That's why this function can be used to detect and filter these cases.
Args:
label: Label tensor of shape [num_frames, height, width, 1].
decoder_output_stride: Integer, the stride of the decoder output.
Returns:
Boolean, whether the labels have enough pixels of each object in the first
frame.
"""
h, w = train_utils.resolve_shape(label)[1:3]
h_sub = model.scale_dimension(h, 1.0 / decoder_output_stride)
w_sub = model.scale_dimension(w, 1.0 / decoder_output_stride)
label_downscaled = tf.squeeze(
tf.image.resize_nearest_neighbor(label[0, tf.newaxis], [h_sub, w_sub],
align_corners=True), axis=0)
_, _, counts = tf.unique_with_counts(
tf.reshape(label_downscaled, [-1]))
has_enough_pixels_per_object = tf.reduce_all(
tf.greater_equal(counts, MIN_LABEL_COUNT))
return has_enough_pixels_per_object
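# Illustrative numpy sketch (not part of the original code) of the check
# above: after downscaling, every label present in the first frame (including
# background 0) must cover at least MIN_LABEL_COUNT pixels for the sample to
# be kept.
def _example_has_enough_pixels(label_downscaled, min_label_count=10):
  import numpy as np
  _, counts = np.unique(label_downscaled, return_counts=True)
  return bool(np.all(counts >= min_label_count))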
def get(dataset,
num_frames_per_video,
crop_size,
batch_size,
min_resize_value=None,
max_resize_value=None,
resize_factor=None,
min_scale_factor=1.,
max_scale_factor=1.,
scale_factor_step_size=0,
preprocess_image_and_label=True,
num_readers=1,
num_threads=1,
dataset_split=None,
is_training=True,
model_variant=None,
batch_capacity_factor=32,
video_frames_are_decoded=False,
decoder_output_stride=None,
first_frame_finetuning=False,
sample_only_first_frame_for_finetuning=False,
sample_adjacent_and_consistent_query_frames=False,
remap_labels_to_reference_frame=True,
generate_prev_frame_mask_by_mask_damaging=False,
three_frame_dataset=False,
add_prev_frame_label=True):
"""Gets the dataset split for semantic segmentation.
  This function gets the dataset split for semantic segmentation. In
  particular, it is a wrapper of (1) dataset_data_provider, which returns the
  raw dataset split, (2) input_preprocess, which preprocesses the raw data, and
  (3) the TensorFlow operation that batches the preprocessed data. The output
  can then be directly used for training, evaluation or visualization.
Args:
dataset: An instance of slim Dataset.
num_frames_per_video: The number of frames used per video
crop_size: Image crop size [height, width].
batch_size: Batch size.
min_resize_value: Desired size of the smaller image side.
max_resize_value: Maximum allowed size of the larger image side.
resize_factor: Resized dimensions are multiple of factor plus one.
min_scale_factor: Minimum scale factor value.
max_scale_factor: Maximum scale factor value.
scale_factor_step_size: The step size from min scale factor to max scale
factor. The input is randomly scaled based on the value of
(min_scale_factor, max_scale_factor, scale_factor_step_size).
preprocess_image_and_label: Boolean variable specifies if preprocessing of
image and label will be performed or not.
num_readers: Number of readers for data provider.
num_threads: Number of threads for batching data.
dataset_split: Dataset split.
is_training: Is training or not.
model_variant: Model variant (string) for choosing how to mean-subtract the
images. See feature_extractor.network_map for supported model variants.
batch_capacity_factor: Batch capacity factor affecting the training queue
batch capacity.
video_frames_are_decoded: Boolean, whether the video frames are already
decoded
decoder_output_stride: Integer, the stride of the decoder output.
first_frame_finetuning: Boolean, whether to only sample the first frame
for fine-tuning.
sample_only_first_frame_for_finetuning: Boolean, whether to only sample the
first frame during fine-tuning. This should be False when using lucid or
      wonderland data, but True when fine-tuning on the first frame only.
Only has an effect if first_frame_finetuning is True.
sample_adjacent_and_consistent_query_frames: Boolean, if true, the query
frames (all but the first frame which is the reference frame) will be
sampled such that they are adjacent video frames and have the same
crop coordinates and flip augmentation.
remap_labels_to_reference_frame: Boolean, whether to remap the labels of
the query frames to match the labels of the (downscaled) reference frame.
If a query frame contains a label which is not present in the reference,
it will be mapped to background.
generate_prev_frame_mask_by_mask_damaging: Boolean, whether to generate
the masks used as guidance from the previous frame by damaging the
ground truth mask.
three_frame_dataset: Boolean, whether the dataset has exactly three frames
per video of which the first is to be used as reference and the two
others are consecutive frames to be used as query frames.
add_prev_frame_label: Boolean, whether to sample one more frame before the
first query frame to obtain a previous frame label. Only has an effect,
if sample_adjacent_and_consistent_query_frames is True and
generate_prev_frame_mask_by_mask_damaging is False.
Returns:
A dictionary of batched Tensors for semantic segmentation.
Raises:
ValueError: dataset_split is None, or Failed to find labels.
"""
if dataset_split is None:
raise ValueError('Unknown dataset split.')
if model_variant is None:
tf.logging.warning('Please specify a model_variant. See '
'feature_extractor.network_map for supported model '
'variants.')
data_provider = dataset_data_provider.DatasetDataProvider(
dataset,
num_readers=num_readers,
num_epochs=None if is_training else 1,
shuffle=is_training)
image, label, object_label, image_name, height, width, video_id = _get_data(
data_provider, dataset_split, video_frames_are_decoded)
sampling_is_valid = tf.constant(True)
if num_frames_per_video is not None:
total_num_frames = tf.shape(image)[0]
if first_frame_finetuning or three_frame_dataset:
if sample_only_first_frame_for_finetuning:
assert not sample_adjacent_and_consistent_query_frames, (
'this option does not make sense for sampling only first frame.')
# Sample the first frame num_frames_per_video times.
sel_indices = tf.tile(tf.constant(0, dtype=tf.int32)[tf.newaxis],
multiples=[num_frames_per_video])
else:
if sample_adjacent_and_consistent_query_frames:
if add_prev_frame_label:
num_frames_per_video += 1
# Since this is first frame fine-tuning, we'll for now assume that
# each sequence has exactly 3 images: the ref frame and 2 adjacent
# query frames.
assert num_frames_per_video == 3
with tf.control_dependencies([tf.assert_equal(total_num_frames, 3)]):
sel_indices = tf.constant([1, 2], dtype=tf.int32)
else:
# Sample num_frames_per_video - 1 query frames which are not the
# first frame.
sel_indices = tf.random_shuffle(
tf.range(1, total_num_frames))[:(num_frames_per_video - 1)]
# Concat first frame as reference frame to the front.
sel_indices = tf.concat([tf.constant(0, dtype=tf.int32)[tf.newaxis],
sel_indices], axis=0)
else:
if sample_adjacent_and_consistent_query_frames:
if add_prev_frame_label:
# Sample one more frame which we can use to provide initial softmax
# feedback.
num_frames_per_video += 1
ref_idx = tf.random_shuffle(tf.range(total_num_frames))[0]
sampling_is_valid = tf.greater_equal(total_num_frames,
num_frames_per_video)
def sample_query_start_idx():
return tf.random_shuffle(
tf.range(total_num_frames - num_frames_per_video + 1))[0]
query_start_idx = tf.cond(sampling_is_valid, sample_query_start_idx,
lambda: tf.constant(0, dtype=tf.int32))
def sample_sel_indices():
return tf.concat(
[ref_idx[tf.newaxis],
tf.range(
query_start_idx,
query_start_idx + (num_frames_per_video - 1))], axis=0)
sel_indices = tf.cond(
sampling_is_valid, sample_sel_indices,
lambda: tf.zeros((num_frames_per_video,), dtype=tf.int32))
else:
# Randomly sample some frames from the video.
sel_indices = tf.random_shuffle(
tf.range(total_num_frames))[:num_frames_per_video]
image = tf.gather(image, sel_indices, axis=0)
if not video_frames_are_decoded:
image = decode_image_sequence(image)
if label is not None:
if num_frames_per_video is not None:
label = tf.gather(label, sel_indices, axis=0)
if not video_frames_are_decoded:
label = decode_image_sequence(label, image_format='png', channels=1)
# Sometimes, label is saved as [num_frames_per_video, height, width] or
# [num_frames_per_video, height, width, 1]. We change it to be
# [num_frames_per_video, height, width, 1].
if label.shape.ndims == 3:
label = tf.expand_dims(label, 3)
elif label.shape.ndims == 4 and label.shape.dims[3] == 1:
pass
else:
      raise ValueError('Input label shape must be '
                       '[num_frames_per_video, height, width] or '
                       '[num_frames_per_video, height, width, 1]. '
                       'Got ndims {}'.format(label.shape.ndims))
label.set_shape([None, None, None, 1])
# Add size of first dimension since tf can't figure it out automatically.
image.set_shape((num_frames_per_video, None, None, None))
if label is not None:
label.set_shape((num_frames_per_video, None, None, None))
preceding_frame_label = None
if preprocess_image_and_label:
if num_frames_per_video is None:
raise ValueError('num_frame_per_video must be specified for preproc.')
original_images = []
images = []
labels = []
if sample_adjacent_and_consistent_query_frames:
num_frames_individual_preproc = 1
else:
num_frames_individual_preproc = num_frames_per_video
for frame_idx in range(num_frames_individual_preproc):
original_image_t, image_t, label_t = (
input_preprocess.preprocess_image_and_label(
image[frame_idx],
label[frame_idx],
crop_height=crop_size[0] if crop_size is not None else None,
crop_width=crop_size[1] if crop_size is not None else None,
min_resize_value=min_resize_value,
max_resize_value=max_resize_value,
resize_factor=resize_factor,
min_scale_factor=min_scale_factor,
max_scale_factor=max_scale_factor,
scale_factor_step_size=scale_factor_step_size,
ignore_label=dataset.ignore_label,
is_training=is_training,
model_variant=model_variant))
original_images.append(original_image_t)
images.append(image_t)
labels.append(label_t)
if sample_adjacent_and_consistent_query_frames:
imgs_for_preproc = [image[frame_idx] for frame_idx in
range(1, num_frames_per_video)]
labels_for_preproc = [label[frame_idx] for frame_idx in
range(1, num_frames_per_video)]
original_image_rest, image_rest, label_rest = (
input_preprocess.preprocess_images_and_labels_consistently(
imgs_for_preproc,
labels_for_preproc,
crop_height=crop_size[0] if crop_size is not None else None,
crop_width=crop_size[1] if crop_size is not None else None,
min_resize_value=min_resize_value,
max_resize_value=max_resize_value,
resize_factor=resize_factor,
min_scale_factor=min_scale_factor,
max_scale_factor=max_scale_factor,
scale_factor_step_size=scale_factor_step_size,
ignore_label=dataset.ignore_label,
is_training=is_training,
model_variant=model_variant))
original_images.extend(original_image_rest)
images.extend(image_rest)
labels.extend(label_rest)
assert len(original_images) == num_frames_per_video
assert len(images) == num_frames_per_video
assert len(labels) == num_frames_per_video
if remap_labels_to_reference_frame:
# Remap labels to indices into the labels of the (downscaled) reference
# frame, or 0, i.e. background, for labels which are not present
# in the reference.
reference_labels = labels[0][tf.newaxis]
h, w = train_utils.resolve_shape(reference_labels)[1:3]
embedding_height = model.scale_dimension(
h, 1.0 / decoder_output_stride)
embedding_width = model.scale_dimension(
w, 1.0 / decoder_output_stride)
reference_labels_embedding_size = tf.squeeze(
tf.image.resize_nearest_neighbor(
reference_labels, tf.stack([embedding_height, embedding_width]),
align_corners=True),
axis=0)
# Get sorted unique labels in the reference frame.
labels_in_ref_frame, _ = tf.unique(
tf.reshape(reference_labels_embedding_size, [-1]))
labels_in_ref_frame = tf.contrib.framework.sort(labels_in_ref_frame)
for idx in range(1, len(labels)):
ref_label_mask = tf.equal(
labels[idx],
labels_in_ref_frame[tf.newaxis, tf.newaxis, :])
remapped = tf.argmax(tf.cast(ref_label_mask, tf.uint8), axis=-1,
output_type=tf.int32)
# Set to 0 if label is not present
is_in_ref = tf.reduce_any(ref_label_mask, axis=-1)
remapped *= tf.cast(is_in_ref, tf.int32)
labels[idx] = remapped[..., tf.newaxis]
if sample_adjacent_and_consistent_query_frames:
if first_frame_finetuning and generate_prev_frame_mask_by_mask_damaging:
preceding_frame_label = mask_damaging.damage_masks(labels[1])
elif add_prev_frame_label:
# Discard the image of the additional frame and take the label as
# initialization for softmax feedback.
original_images = [original_images[0]] + original_images[2:]
preceding_frame_label = labels[1]
images = [images[0]] + images[2:]
labels = [labels[0]] + labels[2:]
num_frames_per_video -= 1
original_image = tf.stack(original_images, axis=0)
image = tf.stack(images, axis=0)
label = tf.stack(labels, axis=0)
else:
if label is not None:
# Need to set label shape due to batching.
label.set_shape([num_frames_per_video,
None if crop_size is None else crop_size[0],
None if crop_size is None else crop_size[1],
1])
original_image = tf.to_float(tf.zeros_like(label))
if crop_size is None:
height = tf.shape(image)[1]
width = tf.shape(image)[2]
else:
height = crop_size[0]
width = crop_size[1]
sample = {'image': image,
'image_name': image_name,
'height': height,
'width': width,
'video_id': video_id}
if label is not None:
sample['label'] = label
if object_label is not None:
sample['object_label'] = object_label
if preceding_frame_label is not None:
sample['preceding_frame_label'] = preceding_frame_label
if not is_training:
# Original image is only used during visualization.
sample['original_image'] = original_image
if is_training:
if first_frame_finetuning:
keep_input = tf.constant(True)
else:
keep_input = tf.logical_and(sampling_is_valid, tf.logical_and(
_has_enough_pixels_of_each_object_in_first_frame(
label, decoder_output_stride),
_has_foreground_and_background_in_first_frame_2(
label, decoder_output_stride)))
batched = tf.train.maybe_batch(sample,
keep_input=keep_input,
batch_size=batch_size,
num_threads=num_threads,
capacity=batch_capacity_factor * batch_size,
dynamic_pad=True)
else:
batched = tf.train.batch(sample,
batch_size=batch_size,
num_threads=num_threads,
capacity=batch_capacity_factor * batch_size,
dynamic_pad=True)
  # Flatten from [batch, num_frames_per_video, ...] to
  # [batch * num_frames_per_video, ...].
cropped_height = train_utils.resolve_shape(batched['image'])[2]
cropped_width = train_utils.resolve_shape(batched['image'])[3]
if num_frames_per_video is None:
first_dim = -1
else:
first_dim = batch_size * num_frames_per_video
batched['image'] = tf.reshape(batched['image'],
[first_dim, cropped_height, cropped_width, 3])
if label is not None:
batched['label'] = tf.reshape(batched['label'],
[first_dim, cropped_height, cropped_width, 1])
return batched
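# Illustrative usage sketch (not part of the original code): a training input
# pipeline could call get() roughly as follows; the concrete values below are
# made-up placeholders.
#
#   samples = get(dataset,
#                 num_frames_per_video=3,
#                 crop_size=[465, 465],
#                 batch_size=6,
#                 min_scale_factor=0.5,
#                 max_scale_factor=2.0,
#                 scale_factor_step_size=0.25,
#                 dataset_split='train',
#                 is_training=True,
#                 model_variant='xception_65',
#                 decoder_output_stride=4,
#                 sample_adjacent_and_consistent_query_frames=True)
#   images = samples['image']  # [batch_size * num_frames_per_video, h, w, 3]
#   labels = samples['label']  # [batch_size * num_frames_per_video, h, w, 1]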
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Segmentation results evaluation and visualization for videos using attention.
"""
import math
import os
import time
import numpy as np
import tensorflow as tf
from feelvos import common
from feelvos import model
from feelvos.datasets import video_dataset
from feelvos.utils import embedding_utils
from feelvos.utils import eval_utils
from feelvos.utils import video_input_generator
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer('eval_interval_secs', 60 * 5,
'How often (in seconds) to run evaluation.')
flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')
flags.DEFINE_integer('vis_batch_size', 1,
'The number of images in each batch during evaluation.')
flags.DEFINE_string('vis_logdir', None, 'Where to write the event logs.')
flags.DEFINE_string('checkpoint_dir', None, 'Directory of model checkpoints.')
flags.DEFINE_integer('output_stride', 8,
'The ratio of input to output spatial resolution.')
flags.DEFINE_string('dataset', 'davis_2016',
'Name of the segmentation dataset.')
flags.DEFINE_string('vis_split', 'val',
                    'Which split of the dataset is used for visualizing results.')
flags.DEFINE_string(
'dataset_dir',
'/cns/is-d/home/lcchen/data/pascal_voc_seg/example_sstables',
'Where the dataset resides.')
flags.DEFINE_integer('num_vis_examples', -1,
'Number of examples for visualization. If -1, use all '
'samples in the vis data.')
flags.DEFINE_multi_integer('atrous_rates', None,
'Atrous rates for atrous spatial pyramid pooling.')
flags.DEFINE_bool('save_segmentations', False, 'Whether to save the '
'segmentation masks as '
'png images. Might be slow '
'on cns.')
flags.DEFINE_bool('save_embeddings', False, 'Whether to save the embeddings '
                  'as pickle. Might be slow on cns.')
flags.DEFINE_bool('eval_once_and_quit', False,
'Whether to just run the eval a single time and quit '
'afterwards. Otherwise, the eval is run in a loop with '
'new checkpoints.')
flags.DEFINE_boolean('first_frame_finetuning', False,
'Whether to only sample the first frame for fine-tuning.')
# The folder where segmentations are saved.
_SEGMENTATION_SAVE_FOLDER = 'segmentation'
_EMBEDDINGS_SAVE_FOLDER = 'embeddings'
def _process_seq_data(segmentation_dir, embeddings_dir, seq_name,
predicted_labels, gt_labels, embeddings):
"""Calculates the sequence IoU and optionally save the segmentation masks.
Args:
segmentation_dir: Directory in which the segmentation results are stored.
embeddings_dir: Directory in which the embeddings are stored.
seq_name: String, the name of the sequence.
predicted_labels: Int64 np.array of shape [n_frames, height, width].
gt_labels: Ground truth labels, Int64 np.array of shape
[n_frames, height, width].
embeddings: Float32 np.array of embeddings of shape
[n_frames, decoder_height, decoder_width, embedding_dim], or None.
Returns:
The IoU for the sequence (float).
"""
sequence_dir = os.path.join(segmentation_dir, seq_name)
tf.gfile.MakeDirs(sequence_dir)
embeddings_seq_dir = os.path.join(embeddings_dir, seq_name)
tf.gfile.MakeDirs(embeddings_seq_dir)
label_set = np.unique(gt_labels[0])
ious = []
assert len(predicted_labels) == len(gt_labels)
if embeddings is not None:
assert len(predicted_labels) == len(embeddings)
for t, (predicted_label, gt_label) in enumerate(
zip(predicted_labels, gt_labels)):
if FLAGS.save_segmentations:
seg_filename = os.path.join(segmentation_dir, seq_name, '%05d.png' % t)
eval_utils.save_segmentation_with_colormap(seg_filename, predicted_label)
if FLAGS.save_embeddings:
embedding_filename = os.path.join(embeddings_dir, seq_name,
'%05d.npy' % t)
assert embeddings is not None
eval_utils.save_embeddings(embedding_filename, embeddings[t])
object_ious_t = eval_utils.calculate_multi_object_ious(
predicted_label, gt_label, label_set)
ious.append(object_ious_t)
# First and last frame are excluded in DAVIS eval.
seq_ious = np.mean(ious[1:-1], axis=0)
tf.logging.info('seq ious: %s %s', seq_name, seq_ious)
return seq_ious
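# Illustrative numpy sketch (not part of the original code) of the averaging
# above: per-frame, per-object IoUs are stacked and the first and last frame
# are dropped before taking the per-object mean, following the DAVIS protocol.
def _example_sequence_iou():
  per_frame_ious = np.array([[1.0, 1.0],   # first frame (reference), excluded
                             [0.8, 0.6],
                             [0.9, 0.7],
                             [0.5, 0.3]])  # last frame, excluded
  return np.mean(per_frame_ious[1:-1], axis=0)  # -> [0.85, 0.65]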
def create_predictions(samples, reference_labels, first_frame_img,
model_options):
"""Predicts segmentation labels for each frame of the video.
Slower version than create_predictions_fast, but does support more options.
Args:
samples: Dictionary of input samples.
reference_labels: Int tensor of shape [1, height, width, 1].
first_frame_img: Float32 tensor of shape [height, width, 3].
model_options: An InternalModelOptions instance to configure models.
Returns:
predicted_labels: Int tensor of shape [time, height, width] of
predicted labels for each frame.
all_embeddings: Float32 tensor of shape
[time, height, width, embedding_dim], or None.
"""
def predict(args, imgs):
"""Predicts segmentation labels and softmax probabilities for each image.
Args:
args: A tuple of (predictions, softmax_probabilities), where predictions
is an int tensor of shape [1, h, w] and softmax_probabilities is a
float32 tensor of shape [1, h_decoder, w_decoder, n_objects].
      imgs: Either a one-tuple containing the image to predict labels for, of
        shape [h, w, 3], or a pair of the previous frame and the current frame
        image.
Returns:
predictions: The predicted labels as int tensor of shape [1, h, w].
softmax_probabilities: The softmax probabilities of shape
[1, h_decoder, w_decoder, n_objects].
"""
if FLAGS.save_embeddings:
last_frame_predictions, last_softmax_probabilities, _ = args
else:
last_frame_predictions, last_softmax_probabilities = args
if FLAGS.also_attend_to_previous_frame or FLAGS.use_softmax_feedback:
ref_labels_to_use = tf.concat(
[reference_labels, last_frame_predictions[..., tf.newaxis]],
axis=0)
else:
ref_labels_to_use = reference_labels
predictions, softmax_probabilities = model.predict_labels(
tf.stack((first_frame_img,) + imgs),
model_options=model_options,
image_pyramid=FLAGS.image_pyramid,
embedding_dimension=FLAGS.embedding_dimension,
reference_labels=ref_labels_to_use,
k_nearest_neighbors=FLAGS.k_nearest_neighbors,
use_softmax_feedback=FLAGS.use_softmax_feedback,
initial_softmax_feedback=last_softmax_probabilities,
embedding_seg_feature_dimension=
FLAGS.embedding_seg_feature_dimension,
embedding_seg_n_layers=FLAGS.embedding_seg_n_layers,
embedding_seg_kernel_size=FLAGS.embedding_seg_kernel_size,
embedding_seg_atrous_rates=FLAGS.embedding_seg_atrous_rates,
also_return_softmax_probabilities=True,
num_frames_per_video=
(3 if FLAGS.also_attend_to_previous_frame or
FLAGS.use_softmax_feedback else 2),
normalize_nearest_neighbor_distances=
FLAGS.normalize_nearest_neighbor_distances,
also_attend_to_previous_frame=FLAGS.also_attend_to_previous_frame,
use_local_previous_frame_attention=
FLAGS.use_local_previous_frame_attention,
previous_frame_attention_window_size=
FLAGS.previous_frame_attention_window_size,
use_first_frame_matching=FLAGS.use_first_frame_matching
)
predictions = tf.cast(predictions[common.OUTPUT_TYPE], tf.int32)
if FLAGS.save_embeddings:
names = [n.name for n in tf.get_default_graph().as_graph_def().node]
embedding_names = [x for x in names if 'embeddings' in x]
# This will crash when multi-scale inference is used.
assert len(embedding_names) == 1, len(embedding_names)
embedding_name = embedding_names[0] + ':0'
embeddings = tf.get_default_graph().get_tensor_by_name(embedding_name)
return predictions, softmax_probabilities, embeddings
else:
return predictions, softmax_probabilities
init_labels = tf.squeeze(reference_labels, axis=-1)
init_softmax = embedding_utils.create_initial_softmax_from_labels(
reference_labels, reference_labels, FLAGS.decoder_output_stride,
reduce_labels=False)
if FLAGS.save_embeddings:
decoder_height = tf.shape(init_softmax)[1]
decoder_width = tf.shape(init_softmax)[2]
n_frames = (3 if FLAGS.also_attend_to_previous_frame
or FLAGS.use_softmax_feedback else 2)
embeddings_init = tf.zeros((n_frames, decoder_height, decoder_width,
FLAGS.embedding_dimension))
init = (init_labels, init_softmax, embeddings_init)
else:
init = (init_labels, init_softmax)
# Do not eval the first frame again but concat the first frame ground
# truth instead.
if FLAGS.also_attend_to_previous_frame or FLAGS.use_softmax_feedback:
elems = (samples[common.IMAGE][:-1], samples[common.IMAGE][1:])
else:
elems = (samples[common.IMAGE][1:],)
res = tf.scan(predict, elems,
initializer=init,
parallel_iterations=1,
swap_memory=True)
if FLAGS.save_embeddings:
predicted_labels, _, all_embeddings = res
first_frame_embeddings = all_embeddings[0, 0, tf.newaxis]
other_frame_embeddings = all_embeddings[:, -1]
all_embeddings = tf.concat(
[first_frame_embeddings, other_frame_embeddings], axis=0)
else:
predicted_labels, _ = res
all_embeddings = None
predicted_labels = tf.concat([reference_labels[..., 0],
tf.squeeze(predicted_labels, axis=1)],
axis=0)
return predicted_labels, all_embeddings
def create_predictions_fast(samples, reference_labels, first_frame_img,
model_options):
"""Predicts segmentation labels for each frame of the video.
Faster version than create_predictions, but does not support all options.
Args:
samples: Dictionary of input samples.
reference_labels: Int tensor of shape [1, height, width, 1].
first_frame_img: Float32 tensor of shape [height, width, 3].
model_options: An InternalModelOptions instance to configure models.
  Returns:
    predicted_labels: Int tensor of shape [time, height, width] of
      predicted labels for each frame. Unlike create_predictions, this fast
      version does not return embeddings.
Raises:
ValueError: If FLAGS.save_embeddings is True, FLAGS.use_softmax_feedback is
False, or FLAGS.also_attend_to_previous_frame is False.
"""
if FLAGS.save_embeddings:
raise ValueError('save_embeddings does not work with '
'create_predictions_fast. Use the slower '
'create_predictions instead.')
if not FLAGS.use_softmax_feedback:
raise ValueError('use_softmax_feedback must be True for '
'create_predictions_fast. Use the slower '
'create_predictions instead.')
if not FLAGS.also_attend_to_previous_frame:
raise ValueError('also_attend_to_previous_frame must be True for '
'create_predictions_fast. Use the slower '
'create_predictions instead.')
# Extract embeddings for first frame and prepare initial predictions.
first_frame_embeddings = embedding_utils.get_embeddings(
first_frame_img[tf.newaxis], model_options, FLAGS.embedding_dimension)
init_labels = tf.squeeze(reference_labels, axis=-1)
init_softmax = embedding_utils.create_initial_softmax_from_labels(
reference_labels, reference_labels, FLAGS.decoder_output_stride,
reduce_labels=False)
init = (init_labels, init_softmax, first_frame_embeddings)
def predict(args, img):
"""Predicts segmentation labels and softmax probabilities for each image.
Args:
args: tuple of
(predictions, softmax_probabilities, last_frame_embeddings), where
predictions is an int tensor of shape [1, h, w],
softmax_probabilities is a float32 tensor of shape
[1, h_decoder, w_decoder, n_objects],
and last_frame_embeddings is a float32 tensor of shape
[h_decoder, w_decoder, embedding_dimension].
      img: Image to predict labels for, of shape [h, w, 3].
Returns:
predictions: The predicted labels as int tensor of shape [1, h, w].
softmax_probabilities: The softmax probabilities of shape
[1, h_decoder, w_decoder, n_objects].
"""
(last_frame_predictions, last_softmax_probabilities,
prev_frame_embeddings) = args
ref_labels_to_use = tf.concat(
[reference_labels, last_frame_predictions[..., tf.newaxis]],
axis=0)
predictions, softmax_probabilities, embeddings = model.predict_labels(
img[tf.newaxis],
model_options=model_options,
image_pyramid=FLAGS.image_pyramid,
embedding_dimension=FLAGS.embedding_dimension,
reference_labels=ref_labels_to_use,
k_nearest_neighbors=FLAGS.k_nearest_neighbors,
use_softmax_feedback=FLAGS.use_softmax_feedback,
initial_softmax_feedback=last_softmax_probabilities,
embedding_seg_feature_dimension=
FLAGS.embedding_seg_feature_dimension,
embedding_seg_n_layers=FLAGS.embedding_seg_n_layers,
embedding_seg_kernel_size=FLAGS.embedding_seg_kernel_size,
embedding_seg_atrous_rates=FLAGS.embedding_seg_atrous_rates,
also_return_softmax_probabilities=True,
num_frames_per_video=1,
normalize_nearest_neighbor_distances=
FLAGS.normalize_nearest_neighbor_distances,
also_attend_to_previous_frame=FLAGS.also_attend_to_previous_frame,
use_local_previous_frame_attention=
FLAGS.use_local_previous_frame_attention,
previous_frame_attention_window_size=
FLAGS.previous_frame_attention_window_size,
use_first_frame_matching=FLAGS.use_first_frame_matching,
also_return_embeddings=True,
ref_embeddings=(first_frame_embeddings, prev_frame_embeddings)
)
predictions = tf.cast(predictions[common.OUTPUT_TYPE], tf.int32)
return predictions, softmax_probabilities, embeddings
# Do not eval the first frame again but concat the first frame ground
# truth instead.
# If you have a lot of GPU memory, you can try to set swap_memory=False,
# and/or parallel_iterations=2.
elems = samples[common.IMAGE][1:]
res = tf.scan(predict, elems,
initializer=init,
parallel_iterations=1,
swap_memory=True)
predicted_labels, _, _ = res
predicted_labels = tf.concat([reference_labels[..., 0],
tf.squeeze(predicted_labels, axis=1)],
axis=0)
return predicted_labels
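# Illustrative sketch (not part of the original code) of how tf.scan threads
# state from one frame to the next above: each step consumes the previous
# frame's predictions (plus softmax feedback and embeddings) and emits the
# state for the next step. A plain Python analogue, with predict_one_frame as
# a placeholder for the model call:
def _example_scan_like_loop(frames, initial_state, predict_one_frame):
  state = initial_state
  outputs = []
  for frame in frames:
    state = predict_one_frame(state, frame)
    outputs.append(state)
  return outputs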
def main(unused_argv):
if FLAGS.vis_batch_size != 1:
raise ValueError('Only batch size 1 is supported for now')
data_type = 'tf_sequence_example'
# Get dataset-dependent information.
dataset = video_dataset.get_dataset(
FLAGS.dataset,
FLAGS.vis_split,
dataset_dir=FLAGS.dataset_dir,
data_type=data_type)
# Prepare for visualization.
tf.gfile.MakeDirs(FLAGS.vis_logdir)
segmentation_dir = os.path.join(FLAGS.vis_logdir, _SEGMENTATION_SAVE_FOLDER)
tf.gfile.MakeDirs(segmentation_dir)
embeddings_dir = os.path.join(FLAGS.vis_logdir, _EMBEDDINGS_SAVE_FOLDER)
tf.gfile.MakeDirs(embeddings_dir)
num_vis_examples = (dataset.num_videos if (FLAGS.num_vis_examples < 0)
else FLAGS.num_vis_examples)
if FLAGS.first_frame_finetuning:
num_vis_examples = 1
tf.logging.info('Visualizing on %s set', FLAGS.vis_split)
g = tf.Graph()
with g.as_default():
# Without setting device to CPU we run out of memory.
with tf.device('cpu:0'):
samples = video_input_generator.get(
dataset,
None,
None,
FLAGS.vis_batch_size,
min_resize_value=FLAGS.min_resize_value,
max_resize_value=FLAGS.max_resize_value,
resize_factor=FLAGS.resize_factor,
dataset_split=FLAGS.vis_split,
is_training=False,
model_variant=FLAGS.model_variant,
preprocess_image_and_label=False,
remap_labels_to_reference_frame=False)
samples[common.IMAGE] = tf.cast(samples[common.IMAGE], tf.float32)
samples[common.LABEL] = tf.cast(samples[common.LABEL], tf.int32)
first_frame_img = samples[common.IMAGE][0]
reference_labels = samples[common.LABEL][0, tf.newaxis]
gt_labels = tf.squeeze(samples[common.LABEL], axis=-1)
seq_name = samples[common.VIDEO_ID][0]
model_options = common.VideoModelOptions(
outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_classes},
crop_size=None,
atrous_rates=FLAGS.atrous_rates,
output_stride=FLAGS.output_stride)
all_embeddings = None
predicted_labels = create_predictions_fast(
samples, reference_labels, first_frame_img, model_options)
# If you need more options like saving embeddings, replace the call to
# create_predictions_fast with create_predictions.
tf.train.get_or_create_global_step()
saver = tf.train.Saver(slim.get_variables_to_restore())
sv = tf.train.Supervisor(graph=g,
logdir=FLAGS.vis_logdir,
init_op=tf.global_variables_initializer(),
summary_op=None,
summary_writer=None,
global_step=None,
saver=saver)
num_batches = int(
math.ceil(num_vis_examples / float(FLAGS.vis_batch_size)))
last_checkpoint = None
  # Infinite loop to visualize the results whenever a new checkpoint is created.
while True:
last_checkpoint = slim.evaluation.wait_for_new_checkpoint(
FLAGS.checkpoint_dir, last_checkpoint)
start = time.time()
tf.logging.info(
'Starting visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
time.gmtime()))
tf.logging.info('Visualizing with model %s', last_checkpoint)
all_ious = []
with sv.managed_session(FLAGS.master,
start_standard_services=False) as sess:
sv.start_queue_runners(sess)
sv.saver.restore(sess, last_checkpoint)
for batch in range(num_batches):
ops = [predicted_labels, gt_labels, seq_name]
if FLAGS.save_embeddings:
ops.append(all_embeddings)
tf.logging.info('Visualizing batch %d / %d', batch + 1, num_batches)
res = sess.run(ops)
tf.logging.info('Forwarding done')
pred_labels_val, gt_labels_val, seq_name_val = res[:3]
if FLAGS.save_embeddings:
all_embeddings_val = res[3]
else:
all_embeddings_val = None
seq_ious = _process_seq_data(segmentation_dir, embeddings_dir,
seq_name_val, pred_labels_val,
gt_labels_val, all_embeddings_val)
all_ious.append(seq_ious)
all_ious = np.concatenate(all_ious, axis=0)
tf.logging.info('n_seqs %s, mIoU %f', all_ious.shape, all_ious.mean())
tf.logging.info(
'Finished visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
time.gmtime()))
result_dir = FLAGS.vis_logdir + '/results/'
tf.gfile.MakeDirs(result_dir)
with tf.gfile.GFile(result_dir + seq_name_val + '.txt', 'w') as f:
f.write(str(all_ious))
if FLAGS.first_frame_finetuning or FLAGS.eval_once_and_quit:
break
time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
if time_to_next_eval > 0:
time.sleep(time_to_next_eval)
if __name__ == '__main__':
flags.mark_flag_as_required('checkpoint_dir')
flags.mark_flag_as_required('vis_logdir')
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run()