Unverified Commit 05ccaf88 authored by Lukasz Kaiser's avatar Lukasz Kaiser Committed by GitHub
Browse files

Merge pull request #3521 from YknZhu/master

Add deeplab model in tensorflow models
parents 6571d16d 1e9b07d8
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Visualizes the segmentation results via specified color map.
Visualizes the semantic segmentation results by the color map
defined by the different datasets. Supported colormaps are:
1. PASCAL VOC semantic segmentation benchmark.
Website: http://host.robots.ox.ac.uk/pascal/VOC/
"""
import numpy as np
# Dataset names.
_CITYSCAPES = 'cityscapes'
_PASCAL = 'pascal'

# Max number of entries in the colormap for each dataset.
# Cityscapes uses one fixed color per train class (19 classes); PASCAL uses
# the full 8-bit label space (256 entries) generated bitwise below.
_DATASET_MAX_ENTRIES = {
    _CITYSCAPES: 19,
    _PASCAL: 256,
}
def create_cityscapes_label_colormap():
  """Creates a label colormap used in CITYSCAPES segmentation benchmark.

  Returns:
    A Colormap for visualizing segmentation results.
  """
  # One RGB color per Cityscapes train class, listed in train-id order.
  cityscapes_colors = (
      (128, 64, 128),
      (244, 35, 232),
      (70, 70, 70),
      (102, 102, 156),
      (190, 153, 153),
      (153, 153, 153),
      (250, 170, 30),
      (220, 220, 0),
      (107, 142, 35),
      (152, 251, 152),
      (70, 130, 180),
      (220, 20, 60),
      (255, 0, 0),
      (0, 0, 142),
      (0, 0, 70),
      (0, 60, 100),
      (0, 80, 100),
      (0, 0, 230),
      (119, 11, 32),
  )
  return np.asarray(cityscapes_colors)
def get_pascal_name():
  """Returns the dataset name string for PASCAL VOC ('pascal')."""
  return _PASCAL
def get_cityscapes_name():
  """Returns the dataset name string for Cityscapes ('cityscapes')."""
  return _CITYSCAPES
def bit_get(val, idx):
  """Gets the bit value.

  Args:
    val: Input value, int or numpy int array.
    idx: Which bit of the input val.

  Returns:
    The "idx"-th bit of input val.
  """
  # Mask out the requested bit, then shift it down to position 0.
  return (val & (1 << idx)) >> idx
def create_pascal_label_colormap():
  """Creates a label colormap used in PASCAL VOC segmentation benchmark.

  Returns:
    A Colormap for visualizing segmentation results.
  """
  num_entries = _DATASET_MAX_ENTRIES[_PASCAL]
  colormap = np.zeros((num_entries, 3), dtype=int)
  label_indices = np.arange(num_entries, dtype=int)
  # Build each channel bit by bit, starting from the most significant bit
  # (shift 7); every round consumes the next three bits of the label index,
  # one per channel.
  for shift in reversed(range(8)):
    for channel in range(3):
      colormap[:, channel] |= ((label_indices >> channel) & 1) << shift
    label_indices >>= 3
  return colormap
def create_label_colormap(dataset=_PASCAL):
  """Creates a label colormap for the specified dataset.

  Args:
    dataset: The colormap used in the dataset.

  Returns:
    A numpy array of the dataset colormap.

  Raises:
    ValueError: If the dataset is not supported.
  """
  # Dispatch table: dataset name -> colormap builder.
  colormap_builders = {
      _PASCAL: create_pascal_label_colormap,
      _CITYSCAPES: create_cityscapes_label_colormap,
  }
  if dataset not in colormap_builders:
    raise ValueError('Unsupported dataset.')
  return colormap_builders[dataset]()
def label_to_color_image(label, dataset=_PASCAL):
  """Adds color defined by the dataset colormap to the label.

  Args:
    label: A 2D array with integer type, storing the segmentation label.
    dataset: The colormap used in the dataset.

  Returns:
    result: A 2D array with floating type. The element of the array
      is the color indexed by the corresponding element in the input label
      to the PASCAL color map.

  Raises:
    ValueError: If label is not of rank 2 or its value is larger than color
      map maximum entry.
  """
  # Guard clauses: only rank-2 labels within the colormap range are valid.
  if label.ndim != 2:
    raise ValueError('Expect 2-D input label')
  max_entries = _DATASET_MAX_ENTRIES[dataset]
  if np.max(label) >= max_entries:
    raise ValueError('label value too large.')
  # Fancy indexing maps every label value to its RGB row in the colormap.
  return create_label_colormap(dataset)[label]
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for get_dataset_colormap.py."""
import numpy as np
import tensorflow as tf
from deeplab.utils import get_dataset_colormap
class VisualizationUtilTest(tf.test.TestCase):
  """Unit tests for the colormap helpers in get_dataset_colormap."""

  def testBitGet(self):
    """Test that if the returned bit value is correct."""
    self.assertEqual(1, get_dataset_colormap.bit_get(9, 0))
    self.assertEqual(0, get_dataset_colormap.bit_get(9, 1))
    self.assertEqual(0, get_dataset_colormap.bit_get(9, 2))
    self.assertEqual(1, get_dataset_colormap.bit_get(9, 3))

  def testPASCALLabelColorMapValue(self):
    """Test the generated PASCAL color map values."""
    colormap = get_dataset_colormap.create_pascal_label_colormap()

    # Only test a few sampled entries in the color map.
    self.assertTrue(np.array_equal([128., 0., 128.], colormap[5, :]))
    self.assertTrue(np.array_equal([128., 192., 128.], colormap[23, :]))
    self.assertTrue(np.array_equal([128., 0., 192.], colormap[37, :]))
    self.assertTrue(np.array_equal([224., 192., 192.], colormap[127, :]))
    self.assertTrue(np.array_equal([192., 160., 192.], colormap[175, :]))

  def testLabelToPASCALColorImage(self):
    """Test the value of the converted label value."""
    label = np.array([[0, 16, 16], [52, 7, 52]])
    expected_result = np.array([
        [[0, 0, 0], [0, 64, 0], [0, 64, 0]],
        [[0, 64, 192], [128, 128, 128], [0, 64, 192]]
    ])
    colored_label = get_dataset_colormap.label_to_color_image(
        label, get_dataset_colormap.get_pascal_name())
    self.assertTrue(np.array_equal(expected_result, colored_label))

  def testUnExpectedLabelValueForLabelToPASCALColorImage(self):
    """Raise ValueError when input value exceeds range."""
    label = np.array([[120], [300]])
    with self.assertRaises(ValueError):
      get_dataset_colormap.label_to_color_image(
          label, get_dataset_colormap.get_pascal_name())

  def testUnExpectedLabelDimensionForLabelToPASCALColorImage(self):
    """Raise ValueError if input dimension is not correct."""
    label = np.array([120])
    with self.assertRaises(ValueError):
      get_dataset_colormap.label_to_color_image(
          label, get_dataset_colormap.get_pascal_name())

  def testGetColormapForUnsupportedDataset(self):
    """Raise ValueError for a dataset without a registered colormap."""
    with self.assertRaises(ValueError):
      get_dataset_colormap.create_label_colormap('unsupported_dataset')
# Run all test cases when this file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for providing semantic segmentation data."""
import tensorflow as tf
from deeplab import common
from deeplab import input_preprocess
# Aliases for the TF-Slim library bundled with TensorFlow 1.x and its
# queue-based dataset data provider.
slim = tf.contrib.slim
dataset_data_provider = slim.dataset_data_provider
def _get_data(data_provider, dataset_split):
  """Gets data from data provider.

  Args:
    data_provider: An object of slim.data_provider.
    dataset_split: Dataset split.

  Returns:
    image: Image Tensor.
    label: Label Tensor storing segmentation annotations.
    image_name: Image name.
    height: Image height.
    width: Image width.

  Raises:
    ValueError: Failed to find label.
  """
  if common.LABELS_CLASS not in data_provider.list_items():
    raise ValueError('Failed to find labels.')

  image, height, width = data_provider.get(
      [common.IMAGE, common.HEIGHT, common.WIDTH])

  # Some datasets do not contain image_name; fall back to an empty string.
  image_name = (
      data_provider.get([common.IMAGE_NAME])[0]
      if common.IMAGE_NAME in data_provider.list_items()
      else tf.constant(''))

  # The test split has no groundtruth annotations.
  label = (
      data_provider.get([common.LABELS_CLASS])[0]
      if dataset_split != common.TEST_SET
      else None)

  return image, label, image_name, height, width
def get(dataset,
        crop_size,
        batch_size,
        min_resize_value=None,
        max_resize_value=None,
        resize_factor=None,
        min_scale_factor=1.,
        max_scale_factor=1.,
        scale_factor_step_size=0,
        num_readers=1,
        num_threads=1,
        dataset_split=None,
        is_training=True,
        model_variant=None):
  """Gets the dataset split for semantic segmentation.

  This functions gets the dataset split for semantic segmentation. In
  particular, it is a wrapper of (1) dataset_data_provider which returns the
  raw dataset split, (2) input_preprocess which preprocesses the raw data, and
  (3) the Tensorflow operation of batching the preprocessed data. Then, the
  output could be directly used by training, evaluation or visualization.

  Args:
    dataset: An instance of slim Dataset.
    crop_size: Image crop size [height, width].
    batch_size: Batch size.
    min_resize_value: Desired size of the smaller image side.
    max_resize_value: Maximum allowed size of the larger image side.
    resize_factor: Resized dimensions are multiple of factor plus one.
    min_scale_factor: Minimum scale factor value.
    max_scale_factor: Maximum scale factor value.
    scale_factor_step_size: The step size from min scale factor to max scale
      factor. The input is randomly scaled based on the value of
      (min_scale_factor, max_scale_factor, scale_factor_step_size).
    num_readers: Number of readers for data provider.
    num_threads: Number of threads for batching data.
    dataset_split: Dataset split.
    is_training: Is training or not.
    model_variant: Model variant (string) for choosing how to mean-subtract
      the images. See feature_extractor.network_map for supported model
      variants.

  Returns:
    A dictionary of batched Tensors for semantic segmentation.

  Raises:
    ValueError: dataset_split is None, failed to find labels, or label shape
      is not valid.
  """
  if dataset_split is None:
    raise ValueError('Unknown dataset split.')
  if model_variant is None:
    tf.logging.warning('Please specify a model_variant. See '
                       'feature_extractor.network_map for supported model '
                       'variants.')

  data_provider = dataset_data_provider.DatasetDataProvider(
      dataset,
      num_readers=num_readers,
      num_epochs=None if is_training else 1,
      shuffle=is_training)
  image, label, image_name, height, width = _get_data(data_provider,
                                                      dataset_split)
  if label is not None:
    # Normalize label shape to [height, width, 1].
    if label.shape.ndims == 2:
      label = tf.expand_dims(label, 2)
    elif label.shape.ndims == 3 and label.shape.dims[2] == 1:
      pass
    else:
      raise ValueError('Input label shape must be [height, width], or '
                       '[height, width, 1].')
    label.set_shape([None, None, 1])
  original_image, image, label = input_preprocess.preprocess_image_and_label(
      image,
      label,
      crop_height=crop_size[0],
      crop_width=crop_size[1],
      min_resize_value=min_resize_value,
      max_resize_value=max_resize_value,
      resize_factor=resize_factor,
      min_scale_factor=min_scale_factor,
      max_scale_factor=max_scale_factor,
      scale_factor_step_size=scale_factor_step_size,
      ignore_label=dataset.ignore_label,
      is_training=is_training,
      model_variant=model_variant)
  sample = {
      common.IMAGE: image,
      common.IMAGE_NAME: image_name,
      common.HEIGHT: height,
      common.WIDTH: width
  }
  if label is not None:
    sample[common.LABEL] = label
  if not is_training:
    # Original image is only used during visualization.
    # BUG FIX: a stray trailing comma previously made this a 1-tuple
    # (original_image,), so consumers received a tuple instead of a Tensor.
    sample[common.ORIGINAL_IMAGE] = original_image
    # Use a single thread so images are read in a deterministic order during
    # evaluation/visualization.
    num_threads = 1
  return tf.train.batch(
      sample,
      batch_size=batch_size,
      num_threads=num_threads,
      capacity=32 * batch_size,
      allow_smaller_final_batch=not is_training,
      dynamic_pad=True)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Saves an annotation as one png image.
This script saves an annotation as one png image, and has the option to add
colormap to the png image for better visualization.
"""
import numpy as np
import PIL.Image as img
import tensorflow as tf
from deeplab.utils import get_dataset_colormap
def save_annotation(label,
                    save_dir,
                    filename,
                    add_colormap=True,
                    colormap_type=get_dataset_colormap.get_pascal_name()):
  """Saves the given label to image on disk.

  Args:
    label: The numpy array to be saved. The data will be converted
      to uint8 and saved as png image.
    save_dir: The directory to which the results will be saved.
    filename: The image filename.
    add_colormap: Add color map to the label or not.
    colormap_type: Colormap type for visualization.
  """
  # Add colormap for visualizing the prediction.
  if add_colormap:
    colored_label = get_dataset_colormap.label_to_color_image(
        label, colormap_type)
  else:
    colored_label = label

  pil_image = img.fromarray(colored_label.astype(dtype=np.uint8))
  # BUG FIX: PNG is a binary format; the file must be opened in binary mode
  # ('wb', not 'w') or the write fails/corrupts output under Python 3.
  with tf.gfile.Open('%s/%s.png' % (save_dir, filename), mode='wb') as f:
    pil_image.save(f, 'PNG')
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for training."""
import tensorflow as tf
# Alias for the TF-Slim library bundled with TensorFlow 1.x.
slim = tf.contrib.slim
def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
                                                  labels,
                                                  num_classes,
                                                  ignore_label,
                                                  loss_weight=1.0,
                                                  upsample_logits=True,
                                                  scope=None):
  """Adds softmax cross entropy loss for logits of each scale.

  Args:
    scales_to_logits: A map from logits names for different scales to logits.
      The logits have shape [batch, logits_height, logits_width, num_classes].
    labels: Groundtruth labels with shape [batch, image_height, image_width, 1].
    num_classes: Integer, number of target classes.
    ignore_label: Integer, label to ignore.
    loss_weight: Float, loss weight.
    upsample_logits: Boolean, upsample logits or not.
    scope: String, the scope for the loss.

  Raises:
    ValueError: Label or logits is None.
  """
  if labels is None:
    raise ValueError('No label for softmax cross entropy loss.')

  # BUG FIX: use items() instead of the Python-2-only iteritems() so this
  # code also runs under Python 3; behavior on Python 2 is unchanged.
  for scale, logits in scales_to_logits.items():
    loss_scope = None
    if scope:
      loss_scope = '%s_%s' % (scope, scale)

    if upsample_logits:
      # Label is not downsampled, and instead we upsample logits.
      logits = tf.image.resize_bilinear(
          logits, tf.shape(labels)[1:3], align_corners=True)
      scaled_labels = labels
    else:
      # Label is downsampled to the same size as logits.
      scaled_labels = tf.image.resize_nearest_neighbor(
          labels, tf.shape(logits)[1:3], align_corners=True)

    scaled_labels = tf.reshape(scaled_labels, shape=[-1])
    # Pixels labeled ignore_label contribute zero weight to the loss.
    not_ignore_mask = tf.to_float(tf.not_equal(scaled_labels,
                                               ignore_label)) * loss_weight
    one_hot_labels = slim.one_hot_encoding(
        scaled_labels, num_classes, on_value=1.0, off_value=0.0)
    tf.losses.softmax_cross_entropy(
        one_hot_labels,
        tf.reshape(logits, shape=[-1, num_classes]),
        weights=not_ignore_mask,
        scope=loss_scope)
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=False):
  """Gets the function initializing model variables from a checkpoint.

  Args:
    train_logdir: Log directory for training.
    tf_initial_checkpoint: TensorFlow checkpoint for initialization.
    initialize_last_layer: Initialize last layer or not.
    last_layers: Last layers of the model.
    ignore_missing_vars: Ignore missing variables in the checkpoint.

  Returns:
    Initialization function.
  """
  # No initial checkpoint supplied: nothing to restore from.
  if tf_initial_checkpoint is None:
    tf.logging.info('Not initializing the model from a checkpoint.')
    return None

  # A checkpoint already written to train_logdir wins: resuming training must
  # not clobber it with the initial weights.
  if tf.train.latest_checkpoint(train_logdir):
    tf.logging.info('Ignoring initialization; other checkpoint exists')
    return None

  tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)

  # Never restore the global step; optionally skip the last layers so they
  # can be trained from scratch for the new task.
  excluded_scopes = ['global_step']
  if not initialize_last_layer:
    excluded_scopes += last_layers

  restorable_variables = slim.get_variables_to_restore(exclude=excluded_scopes)
  return slim.assign_from_checkpoint_fn(
      tf_initial_checkpoint,
      restorable_variables,
      ignore_missing_vars=ignore_missing_vars)
def get_model_gradient_multipliers(last_layers, last_layer_gradient_multiplier):
  """Gets the gradient multipliers.

  The gradient multipliers will adjust the learning rates for model
  variables. For the task of semantic segmentation, the models are
  usually fine-tuned from the models trained on the task of image
  classification. To fine-tune the models, we usually set larger (e.g.,
  10 times larger) learning rate for the parameters of last layer.

  Args:
    last_layers: Scopes of last layers.
    last_layer_gradient_multiplier: The gradient multiplier for last layers.

  Returns:
    The gradient multiplier map with variables as key, and multipliers as
    value.
  """
  gradient_multipliers = {}
  for variable in slim.get_model_variables():
    name = variable.op.name
    is_bias = 'biases' in name

    # Double the learning rate for biases.
    if is_bias:
      gradient_multipliers[name] = 2.

    # Use larger learning rate for last layer variables; a bias in a last
    # layer gets both boosts. Only the first matching scope applies.
    for layer_scope in last_layers:
      if layer_scope in name:
        gradient_multipliers[name] = (
            2 * last_layer_gradient_multiplier if is_bias
            else last_layer_gradient_multiplier)
        break
  return gradient_multipliers
def get_model_learning_rate(
    learning_policy, base_learning_rate, learning_rate_decay_step,
    learning_rate_decay_factor, training_number_of_steps, learning_power,
    slow_start_step, slow_start_learning_rate):
  """Gets model's learning rate.

  Computes the model's learning rate for different learning policy.
  Right now, only "step" and "poly" are supported.
  (1) The learning policy for "step" is computed as follows:
    current_learning_rate = base_learning_rate *
      learning_rate_decay_factor ^ (global_step / learning_rate_decay_step)
  See tf.train.exponential_decay for details.
  (2) The learning policy for "poly" is computed as follows:
    current_learning_rate = base_learning_rate *
      (1 - global_step / training_number_of_steps) ^ learning_power

  Args:
    learning_policy: Learning rate policy for training.
    base_learning_rate: The base learning rate for model training.
    learning_rate_decay_step: Decay the base learning rate at a fixed step.
    learning_rate_decay_factor: The rate to decay the base learning rate.
    training_number_of_steps: Number of steps for training.
    learning_power: Power used for 'poly' learning policy.
    slow_start_step: Training model with small learning rate for the first
      few steps.
    slow_start_learning_rate: The learning rate employed during slow start.

  Returns:
    Learning rate for the specified learning policy.

  Raises:
    ValueError: If learning policy is not recognized.
  """
  global_step = tf.train.get_or_create_global_step()

  if learning_policy == 'step':
    decayed_learning_rate = tf.train.exponential_decay(
        base_learning_rate,
        global_step,
        learning_rate_decay_step,
        learning_rate_decay_factor,
        staircase=True)
  elif learning_policy == 'poly':
    decayed_learning_rate = tf.train.polynomial_decay(
        base_learning_rate,
        global_step,
        training_number_of_steps,
        end_learning_rate=0,
        power=learning_power)
  else:
    raise ValueError('Unknown learning policy.')

  # Employ small learning rate at the first few steps for warm start.
  return tf.where(global_step < slow_start_step, slow_start_learning_rate,
                  decayed_learning_rate)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Segmentation results visualization on a given set of images.
See model.py for more details and usage.
"""
import math
import os.path
import time
import numpy as np
import tensorflow as tf
from deeplab import common
from deeplab import model
from deeplab.datasets import segmentation_dataset
from deeplab.utils import input_generator
from deeplab.utils import save_annotation
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')

# Settings for log directories.
flags.DEFINE_string('vis_logdir', None, 'Where to write the event logs.')
flags.DEFINE_string('checkpoint_dir', None, 'Directory of model checkpoints.')

# Settings for visualizing the model.
flags.DEFINE_integer('vis_batch_size', 1,
                     'The number of images in each batch during evaluation.')
flags.DEFINE_multi_integer('vis_crop_size', [513, 513],
                           'Crop size [height, width] for visualization.')
flags.DEFINE_integer('eval_interval_secs', 60 * 5,
                     'How often (in seconds) to run evaluation.')

# For `xception_65`, use atrous_rates = [12, 24, 36] if output_stride = 8, or
# rates = [6, 12, 18] if output_stride = 16. Note one could use different
# atrous_rates/output_stride during training/evaluation.
flags.DEFINE_multi_integer('atrous_rates', None,
                           'Atrous rates for atrous spatial pyramid pooling.')
flags.DEFINE_integer('output_stride', 16,
                     'The ratio of input to output spatial resolution.')

# Change to [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] for multi-scale test.
flags.DEFINE_multi_float('eval_scales', [1.0],
                         'The scales to resize images for evaluation.')

# Change to True for adding flipped images during test.
flags.DEFINE_bool('add_flipped_images', False,
                  'Add flipped images for evaluation or not.')

# Dataset settings.
flags.DEFINE_string('dataset', 'pascal_voc_seg',
                    'Name of the segmentation dataset.')
flags.DEFINE_string('vis_split', 'val',
                    'Which split of the dataset used for visualizing results')
flags.DEFINE_string('dataset_dir', None, 'Where the dataset reside.')
flags.DEFINE_enum('colormap_type', 'pascal', ['pascal', 'cityscapes'],
                  'Visualization colormap type.')
flags.DEFINE_boolean('also_save_raw_predictions', False,
                     'Also save raw predictions.')
flags.DEFINE_integer('max_number_of_iterations', 0,
                     'Maximum number of visualization iterations. Will loop '
                     'indefinitely upon nonpositive values.')

# The folder where semantic segmentation predictions are saved.
_SEMANTIC_PREDICTION_SAVE_FOLDER = 'segmentation_results'

# The folder where raw semantic segmentation predictions are saved.
_RAW_SEMANTIC_PREDICTION_SAVE_FOLDER = 'raw_segmentation_results'

# The format to save image.
_IMAGE_FORMAT = '%06d_image'

# The format to save prediction.
_PREDICTION_FORMAT = '%06d_prediction'

# To evaluate Cityscapes results on the evaluation server, the labels used
# during training should be mapped to the labels for evaluation.
_CITYSCAPES_TRAIN_ID_TO_EVAL_ID = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22,
                                   23, 24, 25, 26, 27, 28, 31, 32, 33]
def _convert_train_id_to_eval_id(prediction, train_id_to_eval_id):
"""Converts the predicted label for evaluation.
There are cases where the training labels are not equal to the evaluation
labels. This function is used to perform the conversion so that we could
evaluate the results on the evaluation server.
Args:
prediction: Semantic segmentation prediction.
train_id_to_eval_id: A list mapping from train id to evaluation id.
Returns:
Semantic segmentation prediction whose labels have been changed.
"""
converted_prediction = prediction.copy()
for train_id, eval_id in enumerate(train_id_to_eval_id):
converted_prediction[prediction == train_id] = eval_id
return converted_prediction
def _process_batch(sess, original_images, semantic_predictions, image_names,
                   image_heights, image_widths, image_id_offset, save_dir,
                   raw_save_dir, train_id_to_eval_id=None):
  """Evaluates one single batch qualitatively.

  Args:
    sess: TensorFlow session.
    original_images: One batch of original images.
    semantic_predictions: One batch of semantic segmentation predictions.
    image_names: Image names.
    image_heights: Image heights.
    image_widths: Image widths.
    image_id_offset: Image id offset for indexing images.
    save_dir: The directory where the predictions will be saved.
    raw_save_dir: The directory where the raw predictions will be saved.
    train_id_to_eval_id: A list mapping from train id to eval id.
  """
  # Materialize the batch of tensors into numpy arrays.
  (original_images,
   semantic_predictions,
   image_names,
   image_heights,
   image_widths) = sess.run([original_images, semantic_predictions,
                             image_names, image_heights, image_widths])

  num_image = semantic_predictions.shape[0]
  for i in range(num_image):
    image_height = np.squeeze(image_heights[i])
    image_width = np.squeeze(image_widths[i])
    original_image = np.squeeze(original_images[i])
    semantic_prediction = np.squeeze(semantic_predictions[i])
    # Remove the region padded during preprocessing, keeping only the valid
    # [height, width] area of the prediction.
    crop_semantic_prediction = semantic_prediction[:image_height, :image_width]

    # Save image.
    save_annotation.save_annotation(
        original_image, save_dir, _IMAGE_FORMAT % (image_id_offset + i),
        add_colormap=False)

    # Save prediction.
    save_annotation.save_annotation(
        crop_semantic_prediction, save_dir,
        _PREDICTION_FORMAT % (image_id_offset + i), add_colormap=True,
        colormap_type=FLAGS.colormap_type)

    if FLAGS.also_save_raw_predictions:
      image_filename = image_names[i]

      if train_id_to_eval_id is not None:
        # Cityscapes: map train ids back to the official evaluation ids
        # expected by the evaluation server.
        crop_semantic_prediction = _convert_train_id_to_eval_id(
            crop_semantic_prediction,
            train_id_to_eval_id)
      save_annotation.save_annotation(
          crop_semantic_prediction, raw_save_dir, image_filename,
          add_colormap=False)
def main(unused_argv):
  """Builds the visualization graph and loops over incoming checkpoints."""
  tf.logging.set_verbosity(tf.logging.INFO)
  # Get dataset-dependent information.
  dataset = segmentation_dataset.get_dataset(
      FLAGS.dataset, FLAGS.vis_split, dataset_dir=FLAGS.dataset_dir)
  train_id_to_eval_id = None
  if dataset.name == segmentation_dataset.get_cityscapes_dataset_name():
    tf.logging.info('Cityscapes requires converting train_id to eval_id.')
    train_id_to_eval_id = _CITYSCAPES_TRAIN_ID_TO_EVAL_ID

  # Prepare output directories for colored and raw predictions.
  tf.gfile.MakeDirs(FLAGS.vis_logdir)
  save_dir = os.path.join(FLAGS.vis_logdir, _SEMANTIC_PREDICTION_SAVE_FOLDER)
  tf.gfile.MakeDirs(save_dir)
  raw_save_dir = os.path.join(
      FLAGS.vis_logdir, _RAW_SEMANTIC_PREDICTION_SAVE_FOLDER)
  tf.gfile.MakeDirs(raw_save_dir)

  tf.logging.info('Visualizing on %s set', FLAGS.vis_split)

  g = tf.Graph()
  with g.as_default():
    samples = input_generator.get(dataset,
                                  FLAGS.vis_crop_size,
                                  FLAGS.vis_batch_size,
                                  min_resize_value=FLAGS.min_resize_value,
                                  max_resize_value=FLAGS.max_resize_value,
                                  resize_factor=FLAGS.resize_factor,
                                  dataset_split=FLAGS.vis_split,
                                  is_training=False,
                                  model_variant=FLAGS.model_variant)

    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_classes},
        crop_size=FLAGS.vis_crop_size,
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    if tuple(FLAGS.eval_scales) == (1.0,):
      tf.logging.info('Performing single-scale test.')
      predictions = model.predict_labels(
          samples[common.IMAGE],
          model_options=model_options,
          image_pyramid=FLAGS.image_pyramid)
    else:
      tf.logging.info('Performing multi-scale test.')
      predictions = model.predict_labels_multi_scale(
          samples[common.IMAGE],
          model_options=model_options,
          eval_scales=FLAGS.eval_scales,
          add_flipped_images=FLAGS.add_flipped_images)
    predictions = predictions[common.OUTPUT_TYPE]

    if FLAGS.min_resize_value and FLAGS.max_resize_value:
      # Only support batch_size = 1, since we assume the dimensions of original
      # image after tf.squeeze is [height, width, 3].
      assert FLAGS.vis_batch_size == 1

      # Reverse the resizing and padding operations performed in preprocessing.
      # First, we slice the valid regions (i.e., remove padded region) and then
      # we resize the predictions back.
      original_image = tf.squeeze(samples[common.ORIGINAL_IMAGE])
      original_image_shape = tf.shape(original_image)
      predictions = tf.slice(
          predictions,
          [0, 0, 0],
          [1, original_image_shape[0], original_image_shape[1]])
      resized_shape = tf.to_int32([tf.squeeze(samples[common.HEIGHT]),
                                   tf.squeeze(samples[common.WIDTH])])
      predictions = tf.squeeze(
          tf.image.resize_images(tf.expand_dims(predictions, 3),
                                 resized_shape,
                                 method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
                                 align_corners=True), 3)

    tf.train.get_or_create_global_step()
    saver = tf.train.Saver(slim.get_variables_to_restore())
    # Supervisor manages session setup; checkpoint restore is done manually
    # below, so summaries and the saver service are disabled here.
    sv = tf.train.Supervisor(graph=g,
                             logdir=FLAGS.vis_logdir,
                             init_op=tf.global_variables_initializer(),
                             summary_op=None,
                             summary_writer=None,
                             global_step=None,
                             saver=saver)
    num_batches = int(math.ceil(
        dataset.num_samples / float(FLAGS.vis_batch_size)))
    last_checkpoint = None

    # Loop to visualize the results when new checkpoint is created.
    num_iters = 0
    while (FLAGS.max_number_of_iterations <= 0 or
           num_iters < FLAGS.max_number_of_iterations):
      num_iters += 1
      # Blocks until a checkpoint newer than last_checkpoint appears.
      last_checkpoint = slim.evaluation.wait_for_new_checkpoint(
          FLAGS.checkpoint_dir, last_checkpoint)
      start = time.time()
      tf.logging.info(
          'Starting visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                       time.gmtime()))
      tf.logging.info('Visualizing with model %s', last_checkpoint)
      with sv.managed_session(FLAGS.master,
                              start_standard_services=False) as sess:
        sv.start_queue_runners(sess)
        sv.saver.restore(sess, last_checkpoint)
        image_id_offset = 0
        for batch in range(num_batches):
          tf.logging.info('Visualizing batch %d / %d', batch + 1, num_batches)
          _process_batch(sess=sess,
                         original_images=samples[common.ORIGINAL_IMAGE],
                         semantic_predictions=predictions,
                         image_names=samples[common.IMAGE_NAME],
                         image_heights=samples[common.HEIGHT],
                         image_widths=samples[common.WIDTH],
                         image_id_offset=image_id_offset,
                         save_dir=save_dir,
                         raw_save_dir=raw_save_dir,
                         train_id_to_eval_id=train_id_to_eval_id)
          image_id_offset += FLAGS.vis_batch_size
      tf.logging.info(
          'Finished visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                       time.gmtime()))
      # Sleep out the remainder of the evaluation interval, if any.
      time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
      if time_to_next_eval > 0:
        time.sleep(time_to_next_eval)
if __name__ == '__main__':
  # These flags have no usable defaults and must be supplied by the user.
  flags.mark_flag_as_required('checkpoint_dir')
  flags.mark_flag_as_required('vis_logdir')
  flags.mark_flag_as_required('dataset_dir')
  tf.app.run()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment