Unverified Commit 05ccaf88 authored by Lukasz Kaiser, committed by GitHub

Merge pull request #3521 from YknZhu/master

Add deeplab model in tensorflow models
parents 6571d16d 1e9b07d8
@@ -8,6 +8,7 @@
/research/brain_coder/ @danabo
/research/cognitive_mapping_and_planning/ @s-gupta
/research/compression/ @nmjohn
/research/deeplab/ @aquariusjay @yknzhu @gpapan
/research/delf/ @andrefaraujo
/research/differential_privacy/ @panyx0718
/research/domain_adaptation/ @bousmalis @dmrd
...
# DeepLab: Deep Labelling for Semantic Image Segmentation
DeepLab is a state-of-the-art deep learning model for semantic image segmentation,
where the goal is to assign semantic labels (e.g., person, dog, cat, and so on)
to every pixel in the input image. The current implementation includes the following
features:
1. DeepLabv1 [1]: We use *atrous convolution* to explicitly control the
resolution at which feature responses are computed within Deep Convolutional
Neural Networks.
2. DeepLabv2 [2]: We use *atrous spatial pyramid pooling* (ASPP) to robustly
segment objects at multiple scales with filters at multiple sampling rates
and effective fields-of-view.
3. DeepLabv3 [3]: We augment the ASPP module with *image-level features* [5, 6]
to capture longer range information. We also include *batch normalization*
[7] parameters to facilitate training. In particular, we apply atrous
convolution to extract output features at different output strides during
training and evaluation, which efficiently enables training BN at output
stride = 16 and attains high performance at output stride = 8 during
evaluation (see the short atrous-convolution sketch after this list).
4. DeepLabv3+ [4]: We extend DeepLabv3 to include a simple yet effective
decoder module to refine the segmentation results, especially along object
boundaries. Furthermore, in this encoder-decoder structure one can
arbitrarily control the resolution of extracted encoder features by atrous
convolution to trade off precision and runtime.
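The following minimal sketch (TF 1.x, not part of the DeepLab code) illustrates the core idea behind atrous convolution: the dilation rate enlarges the effective field of view of a 3x3 kernel while the spatial resolution of the output, and hence the output stride, stays unchanged.
```python
import tensorflow as tf

# A toy batch of images and a single 3x3 kernel mapping 3 -> 8 channels.
images = tf.random_normal([1, 65, 65, 3])
kernel = tf.random_normal([3, 3, 3, 8])

# Same kernel, two dilation rates. With 'SAME' padding both outputs keep the
# 65x65 input resolution; only the field of view of the kernel changes.
out_rate1 = tf.nn.atrous_conv2d(images, kernel, rate=1, padding='SAME')
out_rate6 = tf.nn.atrous_conv2d(images, kernel, rate=6, padding='SAME')

with tf.Session() as sess:
  r1, r6 = sess.run([out_rate1, out_rate6])
  print(r1.shape, r6.shape)  # (1, 65, 65, 8) (1, 65, 65, 8)
```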
If you find the code useful for your research, please consider citing our latest
work:
```
@article{deeplabv3plus2018,
title={Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation},
author={Liang-Chieh Chen and Yukun Zhu and George Papandreou and Florian Schroff and Hartwig Adam},
journal={arXiv:1802.02611},
year={2018}
}
```
In the current implementation, we support adopting the following network
backbones:
1. MobileNetv2 [8]: A fast network structure designed for mobile devices. **We
will provide MobileNetv2 support in the next update. Please stay tuned.**
2. Xception [9, 10]: A powerful network structure intended for server-side
deployment.
This directory contains our TensorFlow [11] implementation. We provide code
that allows users to train the model, evaluate results in terms of mIOU (mean
intersection-over-union), and visualize segmentation results. We use the PASCAL
VOC 2012 [12] and Cityscapes [13] semantic segmentation benchmarks as examples
in the code.
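As a rough, self-contained illustration of the mIOU metric (this is not the evaluation code shipped in this directory), `tf.metrics.mean_iou` averages the per-class intersection-over-union accumulated in a confusion matrix:
```python
import tensorflow as tf

labels = tf.constant([[0, 0, 1, 1]])
predictions = tf.constant([[0, 1, 1, 1]])
miou, update_op = tf.metrics.mean_iou(labels, predictions, num_classes=2)

with tf.Session() as sess:
  # The metric keeps its confusion matrix in local variables.
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(miou))  # IoU(class 0) = 1/2, IoU(class 1) = 2/3, mIOU ~= 0.583
```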
Some segmentation results on Flickr images:
<p align="center">
<img src="g3doc/img/vis1.png" width=600></br>
<img src="g3doc/img/vis2.png" width=600></br>
<img src="g3doc/img/vis3.png" width=600></br>
</p>
## Contacts (Maintainers)
* Liang-Chieh Chen, github: [aquariusjay](https://github.com/aquariusjay)
* YuKun Zhu, github: [yknzhu](https://github.com/YknZhu)
* George Papandreou, github: [gpapan](https://github.com/gpapan)
## Table of Contents
Demo:
* <a href='deeplab_demo.ipynb'>Jupyter notebook for off-the-shelf inference.</a><br>
Running:
* <a href='g3doc/installation.md'>Installation.</a><br>
* <a href='g3doc/pascal.md'>Running DeepLab on PASCAL VOC 2012 semantic segmentation dataset.</a><br>
* <a href='g3doc/cityscapes.md'>Running DeepLab on Cityscapes semantic segmentation dataset.</a><br>
Models:
* <a href='g3doc/model_zoo.md'>Checkpoints and frozen inference graphs.</a><br>
Misc:
* Please check the <a href='g3doc/faq.md'>FAQ</a> if you have questions before reporting an issue.<br>
## Getting Help
To get help with issues you may encounter while using the DeepLab TensorFlow
implementation, create a new question on
[StackOverflow](https://stackoverflow.com/) with the tags "tensorflow" and
"deeplab".
Please report bugs (i.e., broken code, not usage questions) to the
tensorflow/models GitHub [issue
tracker](https://github.com/tensorflow/models/issues), prefixing the issue name
with "deeplab".
## References
1. **Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs**<br />
Liang-Chieh Chen+, George Papandreou+, Iasonas Kokkinos, Kevin Murphy, Alan L. Yuille (+ equal
contribution). <br />
[[link]](https://arxiv.org/abs/1412.7062). In ICLR, 2015.
2. **DeepLab: Semantic Image Segmentation with Deep Convolutional Nets,**
**Atrous Convolution, and Fully Connected CRFs** <br />
Liang-Chieh Chen+, George Papandreou+, Iasonas Kokkinos, Kevin Murphy, and Alan L Yuille (+ equal
contribution). <br />
[[link]](http://arxiv.org/abs/1606.00915). TPAMI 2017.
3. **Rethinking Atrous Convolution for Semantic Image Segmentation**<br />
Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam.<br />
[[link]](http://arxiv.org/abs/1706.05587). arXiv: 1706.05587, 2017.
4. **Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation**<br />
Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, Hartwig Adam. arXiv: 1802.02611.<br />
[[link]](https://arxiv.org/abs/1802.02611). arXiv: 1802.02611, 2018.
5. **ParseNet: Looking Wider to See Better**<br />
Wei Liu, Andrew Rabinovich, Alexander C Berg<br />
[[link]](https://arxiv.org/abs/1506.04579). arXiv:1506.04579, 2015.
6. **Pyramid Scene Parsing Network**<br />
Hengshuang Zhao, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, Jiaya Jia<br />
[[link]](https://arxiv.org/abs/1612.01105). In CVPR, 2017.
7. **Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift**<br />
Sergey Ioffe, Christian Szegedy <br />
[[link]](https://arxiv.org/abs/1502.03167). In ICML, 2015.
8. **Inverted Residuals and Linear Bottlenecks: Mobile Networks for Classification, Detection and Segmentation**<br />
Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen<br />
[[link]](https://arxiv.org/abs/1801.04381). arXiv:1801.04381, 2018.
9. **Xception: Deep Learning with Depthwise Separable Convolutions**<br />
François Chollet<br />
[[link]](https://arxiv.org/abs/1610.02357). In CVPR, 2017.
10. **Deformable Convolutional Networks -- COCO Detection and Segmentation Challenge 2017 Entry**<br />
Haozhi Qi, Zheng Zhang, Bin Xiao, Han Hu, Bowen Cheng, Yichen Wei, Jifeng Dai<br />
[[link]](http://presentations.cocodataset.org/COCO17-Detect-MSRA.pdf). ICCV COCO Challenge
Workshop, 2017.
11. **Tensorflow: Large-Scale Machine Learning on Heterogeneous Distributed Systems**<br />
M. Abadi, A. Agarwal, et al. <br />
[[link]](https://arxiv.org/abs/1603.04467). arXiv:1603.04467, 2016.
12. **The Pascal Visual Object Classes Challenge – A Retrospective,** <br />
Mark Everingham, S. M. Ali Eslami, Luc Van Gool, Christopher K. I. Williams, John
Winn, and Andrew Zisserman. <br />
[[link]](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/). IJCV, 2014.
13. **The Cityscapes Dataset for Semantic Urban Scene Understanding**<br />
Marius Cordts, Mohamed Omran, Sebastian Ramos, Timo Rehfeld, Markus Enzweiler, Rodrigo Benenson, Uwe Franke, Stefan Roth, and Bernt Schiele. <br />
[[link]](https://www.cityscapes-dataset.com/). In CVPR, 2016.
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides flags that are common to scripts.
Common flags from train/eval/vis/export_model.py are collected in this script.
"""
import collections
import tensorflow as tf
flags = tf.app.flags
# Flags for input preprocessing.
flags.DEFINE_integer('min_resize_value', None,
'Desired size of the smaller image side.')
flags.DEFINE_integer('max_resize_value', None,
'Maximum allowed size of the larger image side.')
flags.DEFINE_integer('resize_factor', None,
'Resized dimensions are multiple of factor plus one.')
# Model dependent flags.
flags.DEFINE_integer('logits_kernel_size', 1,
'The kernel size for the convolutional kernel that '
'generates logits.')
# We will support `mobilenet_v2' in the coming update. When using
# 'xception_65', we set atrous_rates = [6, 12, 18] (output stride 16) and
# decoder_output_stride = 4.
flags.DEFINE_enum('model_variant', 'xception_65', ['xception_65'],
'DeepLab model variants.')
flags.DEFINE_multi_float('image_pyramid', None,
'Input scales for multi-scale feature extraction.')
flags.DEFINE_boolean('add_image_level_feature', True,
'Add image level feature.')
flags.DEFINE_boolean('aspp_with_batch_norm', True,
'Use batch norm parameters for ASPP or not.')
flags.DEFINE_boolean('aspp_with_separable_conv', True,
'Use separable convolution for ASPP or not.')
flags.DEFINE_multi_integer('multi_grid', None,
'Employ a hierarchy of atrous rates for ResNet.')
# For `xception_65`, use decoder_output_stride = 4.
flags.DEFINE_integer('decoder_output_stride', None,
'The ratio of input to output spatial resolution when '
'employing decoder to refine segmentation results.')
flags.DEFINE_boolean('decoder_use_separable_conv', True,
'Employ separable convolution for decoder or not.')
flags.DEFINE_enum('merge_method', 'max', ['max', 'avg'],
'Scheme to merge multi scale features.')
FLAGS = flags.FLAGS
# Constants
# Perform semantic segmentation predictions.
OUTPUT_TYPE = 'semantic'
# Semantic segmentation item names.
LABELS_CLASS = 'labels_class'
IMAGE = 'image'
HEIGHT = 'height'
WIDTH = 'width'
IMAGE_NAME = 'image_name'
LABEL = 'label'
ORIGINAL_IMAGE = 'original_image'
# Test set name.
TEST_SET = 'test'
class ModelOptions(
collections.namedtuple('ModelOptions', [
'outputs_to_num_classes',
'crop_size',
'atrous_rates',
'output_stride',
'merge_method',
'add_image_level_feature',
'aspp_with_batch_norm',
'aspp_with_separable_conv',
'multi_grid',
'decoder_output_stride',
'decoder_use_separable_conv',
'logits_kernel_size',
'model_variant'
])):
"""Immutable class to hold model options."""
__slots__ = ()
def __new__(cls,
outputs_to_num_classes,
crop_size=None,
atrous_rates=None,
output_stride=8):
"""Constructor to set default values.
Args:
outputs_to_num_classes: A dictionary from output type to the number of
classes. For example, for the task of semantic segmentation with 21
semantic classes, we would have outputs_to_num_classes['semantic'] = 21.
crop_size: A tuple [crop_height, crop_width].
atrous_rates: A list of atrous convolution rates for ASPP.
output_stride: The ratio of input to output spatial resolution.
Returns:
A new ModelOptions instance.
"""
return super(ModelOptions, cls).__new__(
cls, outputs_to_num_classes, crop_size, atrous_rates, output_stride,
FLAGS.merge_method, FLAGS.add_image_level_feature,
FLAGS.aspp_with_batch_norm, FLAGS.aspp_with_separable_conv,
FLAGS.multi_grid, FLAGS.decoder_output_stride,
FLAGS.decoder_use_separable_conv, FLAGS.logits_kernel_size,
FLAGS.model_variant)
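# Example usage (illustrative only, not executed by this module): constructing
# ModelOptions for 21-class semantic segmentation. The crop size below is a
# placeholder value; atrous_rates = [6, 12, 18] with output stride 16 follows
# the xception_65 setting described in the flag comments above, and the
# remaining fields are filled in from the flags.
#
#   model_options = ModelOptions(
#       outputs_to_num_classes={'semantic': 21},
#       crop_size=[513, 513],
#       atrous_rates=[6, 12, 18],
#       output_stride=16)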
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extracts features for different models."""
import functools
import tensorflow as tf
from deeplab.core import xception
slim = tf.contrib.slim
# A map from network name to network function.
networks_map = {
'xception_65': xception.xception_65,
}
# A map from network name to network arg scope.
arg_scopes_map = {
'xception_65': xception.xception_arg_scope,
}
# Names for end point features.
DECODER_END_POINTS = 'decoder_end_points'
# A dictionary from network name to a map of end point features.
networks_to_feature_maps = {
'xception_65': {
DECODER_END_POINTS: [
'entry_flow/block2/unit_1/xception_module/'
'separable_conv2_pointwise',
],
}
}
# A map from feature extractor name to the network name scope used in the
# ImageNet pretrained versions of these models.
name_scope = {
'xception_65': 'xception_65',
}
# Mean pixel value.
_MEAN_RGB = [123.15, 115.90, 103.06]
def _preprocess_subtract_imagenet_mean(inputs):
"""Subtract Imagenet mean RGB value."""
mean_rgb = tf.reshape(_MEAN_RGB, [1, 1, 1, 3])
return inputs - mean_rgb
def _preprocess_zero_mean_unit_range(inputs):
"""Map image values from [0, 255] to [-1, 1]."""
return (2.0 / 255.0) * tf.to_float(inputs) - 1.0
_PREPROCESS_FN = {
'xception_65': _preprocess_zero_mean_unit_range,
}
def mean_pixel(model_variant=None):
"""Gets mean pixel value.
This function returns a different mean pixel value, depending on the input
model_variant, since different model variants adopt different preprocessing
functions. We currently handle the following preprocessing functions:
(1) _preprocess_subtract_imagenet_mean. We simply return the mean pixel value.
(2) _preprocess_zero_mean_unit_range. We return [127.5, 127.5, 127.5].
The return values are used in a way that the padded regions after
pre-processing will contain value 0.
Args:
model_variant: Model variant (string) for feature extraction. For
backwards compatibility, model_variant=None returns _MEAN_RGB.
Returns:
Mean pixel value.
"""
if model_variant is None:
return _MEAN_RGB
else:
return [127.5, 127.5, 127.5]
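# Note (worked example of the docstring above): under
# _preprocess_zero_mean_unit_range a padded value of 127.5 maps to
# (2.0 / 255.0) * 127.5 - 1.0 = 0.0, so regions padded with mean_pixel()
# contain exactly 0 after preprocessing.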
def extract_features(images,
output_stride=8,
multi_grid=None,
model_variant=None,
weight_decay=0.0001,
reuse=None,
is_training=False,
fine_tune_batch_norm=False,
regularize_depthwise=False,
preprocess_images=True,
num_classes=None,
global_pool=False):
"""Extracts features by the parituclar model_variant.
Args:
images: A tensor of size [batch, height, width, channels].
output_stride: The ratio of input to output spatial resolution.
multi_grid: Employ a hierarchy of different atrous rates within network.
model_variant: Model variant for feature extraction.
weight_decay: The weight decay for model variables.
reuse: Reuse the model variables or not.
is_training: Is training or not.
fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
regularize_depthwise: Whether or not apply L2-norm regularization on the
depthwise convolution weights.
preprocess_images: Performs preprocessing on images or not. Defaults to
True. Set to False if preprocessing will be done by other functions. We
support two types of preprocessing: (1) mean pixel subtraction and (2)
pixel value normalization to [-1, 1].
num_classes: Number of classes for image classification task. Defaults
to None for dense prediction tasks.
global_pool: Global pooling for image classification task. Defaults to
False, since dense prediction tasks do not use this.
Returns:
features: A tensor of size [batch, feature_height, feature_width,
feature_channels], where feature_height/feature_width are determined
by the images height/width and output_stride.
end_points: A dictionary from components of the network to the corresponding
activation.
Raises:
ValueError: Unrecognized model variant.
"""
if 'xception' in model_variant:
arg_scope = arg_scopes_map[model_variant](
weight_decay=weight_decay,
batch_norm_decay=0.9997,
batch_norm_epsilon=1e-3,
batch_norm_scale=True,
regularize_depthwise=regularize_depthwise)
features, end_points = get_network(
model_variant, preprocess_images, arg_scope)(
inputs=images,
num_classes=num_classes,
is_training=(is_training and fine_tune_batch_norm),
global_pool=global_pool,
output_stride=output_stride,
regularize_depthwise=regularize_depthwise,
multi_grid=multi_grid,
reuse=reuse,
scope=name_scope[model_variant])
elif 'mobilenet' in model_variant:
raise ValueError('MobileNetv2 support is coming soon.')
else:
raise ValueError('Unknown model variant %s.' % model_variant)
return features, end_points
def get_network(network_name, preprocess_images, arg_scope=None):
"""Gets the network.
Args:
network_name: Network name.
preprocess_images: Preprocesses the images or not.
arg_scope: Optional, arg_scope to build the network. If not provided the
default arg_scope of the network would be used.
Returns:
A network function that is used to extract features.
Raises:
ValueError: network is not supported.
"""
if network_name not in networks_map:
raise ValueError('Unsupported network %s.' % network_name)
arg_scope = arg_scope or arg_scopes_map[network_name]()
def _identity_function(inputs):
return inputs
if preprocess_images:
preprocess_function = _PREPROCESS_FN[network_name]
else:
preprocess_function = _identity_function
func = networks_map[network_name]
@functools.wraps(func)
def network_fn(inputs, *args, **kwargs):
with slim.arg_scope(arg_scope):
return func(preprocess_function(inputs), *args, **kwargs)
return network_fn
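# Example usage (illustrative only): dense feature extraction with the
# xception_65 backbone at the default output stride of 8. The input size is an
# arbitrary placeholder value.
#
#   images = tf.placeholder(tf.float32, [None, 513, 513, 3])
#   features, end_points = extract_features(
#       images,
#       output_stride=8,
#       model_variant='xception_65',
#       is_training=False)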
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions related to preprocessing inputs."""
import tensorflow as tf
def flip_dim(tensor_list, prob=0.5, dim=1):
"""Randomly flips a dimension of the given tensor.
The decision to randomly flip the `Tensors` is made together. In other words,
all or none of the images passed in are flipped.
Note that tf.random_flip_left_right and tf.random_flip_up_down aren't used so
that we can control for the probability as well as ensure the same decision
is applied across the images.
Args:
tensor_list: A list of `Tensors` with the same number of dimensions.
prob: The probability of a left-right flip.
dim: The dimension to flip, 0, 1, ..
Returns:
outputs: A list of the possibly flipped `Tensors` as well as an indicator
`Tensor` at the end whose value is `True` if the inputs were flipped and
`False` otherwise.
Raises:
ValueError: If dim is negative or greater than the dimension of a `Tensor`.
"""
random_value = tf.random_uniform([])
def flip():
flipped = []
for tensor in tensor_list:
if dim < 0 or dim >= len(tensor.get_shape().as_list()):
raise ValueError('dim must represent a valid dimension.')
flipped.append(tf.reverse_v2(tensor, [dim]))
return flipped
is_flipped = tf.less_equal(random_value, prob)
outputs = tf.cond(is_flipped, flip, lambda: tensor_list)
if not isinstance(outputs, (list, tuple)):
outputs = [outputs]
outputs.append(is_flipped)
return outputs
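# Example usage (illustrative only, mirroring the unit tests): flipping an
# image and its label together along the width dimension; the trailing element
# reports whether the flip happened.
#
#   image, label, is_flipped = flip_dim([image, label], prob=0.5, dim=1)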
def pad_to_bounding_box(image, offset_height, offset_width, target_height,
target_width, pad_value):
"""Pads the given image with the given pad_value.
Works like tf.image.pad_to_bounding_box, except it can pad the image
with any given arbitrary pad value and also handles images whose sizes are not
known during graph construction.
Args:
image: 3-D tensor with shape [height, width, channels]
offset_height: Number of rows of zeros to add on top.
offset_width: Number of columns of zeros to add on the left.
target_height: Height of output image.
target_width: Width of output image.
pad_value: Value to pad the image tensor with.
Returns:
3-D tensor of shape [target_height, target_width, channels].
Raises:
ValueError: If the shape of image is incompatible with the offset_* or
target_* arguments.
"""
image_rank = tf.rank(image)
image_rank_assert = tf.Assert(
tf.equal(image_rank, 3),
['Wrong image tensor rank [Expected] [Actual]',
3, image_rank])
with tf.control_dependencies([image_rank_assert]):
image -= pad_value
image_shape = tf.shape(image)
height, width = image_shape[0], image_shape[1]
target_width_assert = tf.Assert(
tf.greater_equal(
target_width, width),
['target_width must be >= width'])
target_height_assert = tf.Assert(
tf.greater_equal(target_height, height),
['target_height must be >= height'])
with tf.control_dependencies([target_width_assert]):
after_padding_width = target_width - offset_width - width
with tf.control_dependencies([target_height_assert]):
after_padding_height = target_height - offset_height - height
offset_assert = tf.Assert(
tf.logical_and(
tf.greater_equal(after_padding_width, 0),
tf.greater_equal(after_padding_height, 0)),
['target size not possible with the given target offsets'])
height_params = tf.stack([offset_height, after_padding_height])
width_params = tf.stack([offset_width, after_padding_width])
channel_params = tf.stack([0, 0])
with tf.control_dependencies([offset_assert]):
paddings = tf.stack([height_params, width_params, channel_params])
padded = tf.pad(image, paddings)
return padded + pad_value
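# Example usage (illustrative only, taken from the unit tests): padding a 2x2
# image into a 5x5 canvas with a (row, column) offset of (2, 1) and pad
# value 255.
#
#   padded = pad_to_bounding_box(image, 2, 1, 5, 5, 255)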
def _crop(image, offset_height, offset_width, crop_height, crop_width):
"""Crops the given image using the provided offsets and sizes.
Note that the method doesn't assume we know the input image size but it does
assume we know the input image rank.
Args:
image: an image of shape [height, width, channels].
offset_height: a scalar tensor indicating the height offset.
offset_width: a scalar tensor indicating the width offset.
crop_height: the height of the cropped image.
crop_width: the width of the cropped image.
Returns:
The cropped (and resized) image.
Raises:
ValueError: if `image` doesn't have rank of 3.
InvalidArgumentError: if the rank is not 3 or if the image dimensions are
less than the crop size.
"""
original_shape = tf.shape(image)
if len(image.get_shape().as_list()) != 3:
raise ValueError('input must have rank of 3')
original_channels = image.get_shape().as_list()[2]
rank_assertion = tf.Assert(
tf.equal(tf.rank(image), 3),
['Rank of image must be equal to 3.'])
with tf.control_dependencies([rank_assertion]):
cropped_shape = tf.stack([crop_height, crop_width, original_shape[2]])
size_assertion = tf.Assert(
tf.logical_and(
tf.greater_equal(original_shape[0], crop_height),
tf.greater_equal(original_shape[1], crop_width)),
['Crop size greater than the image size.'])
offsets = tf.to_int32(tf.stack([offset_height, offset_width, 0]))
# Use tf.slice instead of crop_to_bounding box as it accepts tensors to
# define the crop size.
with tf.control_dependencies([size_assertion]):
image = tf.slice(image, offsets, cropped_shape)
image = tf.reshape(image, cropped_shape)
image.set_shape([crop_height, crop_width, original_channels])
return image
def random_crop(image_list, crop_height, crop_width):
"""Crops the given list of images.
The function applies the same crop to each image in the list. This can be
effectively applied when there are multiple image inputs of the same
dimension such as:
image, depths, normals = random_crop([image, depths, normals], 120, 150)
Args:
image_list: a list of image tensors of the same dimension but possibly
varying channel.
crop_height: the new height.
crop_width: the new width.
Returns:
the image_list with cropped images.
Raises:
ValueError: if there are multiple image inputs provided with different size
or the images are smaller than the crop dimensions.
"""
if not image_list:
raise ValueError('Empty image_list.')
# Compute the rank assertions.
rank_assertions = []
for i in range(len(image_list)):
image_rank = tf.rank(image_list[i])
rank_assert = tf.Assert(
tf.equal(image_rank, 3),
['Wrong rank for tensor %s [expected] [actual]',
image_list[i].name, 3, image_rank])
rank_assertions.append(rank_assert)
with tf.control_dependencies([rank_assertions[0]]):
image_shape = tf.shape(image_list[0])
image_height = image_shape[0]
image_width = image_shape[1]
crop_size_assert = tf.Assert(
tf.logical_and(
tf.greater_equal(image_height, crop_height),
tf.greater_equal(image_width, crop_width)),
['Crop size greater than the image size.'])
asserts = [rank_assertions[0], crop_size_assert]
for i in range(1, len(image_list)):
image = image_list[i]
asserts.append(rank_assertions[i])
with tf.control_dependencies([rank_assertions[i]]):
shape = tf.shape(image)
height = shape[0]
width = shape[1]
height_assert = tf.Assert(
tf.equal(height, image_height),
['Wrong height for tensor %s [expected][actual]',
image.name, height, image_height])
width_assert = tf.Assert(
tf.equal(width, image_width),
['Wrong width for tensor %s [expected][actual]',
image.name, width, image_width])
asserts.extend([height_assert, width_assert])
# Create a random bounding box.
#
# Use tf.random_uniform and not numpy.random.rand as doing the former would
# generate random numbers at graph eval time, unlike the latter which
# generates random numbers at graph definition time.
with tf.control_dependencies(asserts):
max_offset_height = tf.reshape(image_height - crop_height + 1, [])
max_offset_width = tf.reshape(image_width - crop_width + 1, [])
offset_height = tf.random_uniform(
[], maxval=max_offset_height, dtype=tf.int32)
offset_width = tf.random_uniform(
[], maxval=max_offset_width, dtype=tf.int32)
return [_crop(image, offset_height, offset_width,
crop_height, crop_width) for image in image_list]
def get_random_scale(min_scale_factor, max_scale_factor, step_size):
"""Gets a random scale value.
Args:
min_scale_factor: Minimum scale value.
max_scale_factor: Maximum scale value.
step_size: The step size from minimum to maximum value.
Returns:
A random scale value selected between minimum and maximum value.
Raises:
ValueError: min_scale_factor has unexpected value.
"""
if min_scale_factor < 0 or min_scale_factor > max_scale_factor:
raise ValueError('Unexpected value of min_scale_factor.')
if min_scale_factor == max_scale_factor:
return tf.to_float(min_scale_factor)
# When step_size = 0, we sample the value uniformly from [min, max).
if step_size == 0:
return tf.random_uniform([1],
minval=min_scale_factor,
maxval=max_scale_factor)
# When step_size != 0, we randomly select one discrete value from [min, max].
num_steps = int((max_scale_factor - min_scale_factor) / step_size + 1)
scale_factors = tf.lin_space(min_scale_factor, max_scale_factor, num_steps)
shuffled_scale_factors = tf.random_shuffle(scale_factors)
return shuffled_scale_factors[0]
def randomly_scale_image_and_label(image, label=None, scale=1.0):
"""Randomly scales image and label.
Args:
image: Image with shape [height, width, 3].
label: Label with shape [height, width, 1].
scale: The value to scale image and label.
Returns:
Scaled image and label.
"""
# No random scaling if scale == 1.
if scale == 1.0:
return image, label
image_shape = tf.shape(image)
new_dim = tf.to_int32(tf.to_float([image_shape[0], image_shape[1]]) * scale)
# Need squeeze and expand_dims because image interpolation takes
# 4D tensors as input.
image = tf.squeeze(tf.image.resize_bilinear(
tf.expand_dims(image, 0),
new_dim,
align_corners=True), [0])
if label is not None:
label = tf.squeeze(tf.image.resize_nearest_neighbor(
tf.expand_dims(label, 0),
new_dim,
align_corners=True), [0])
return image, label
def resolve_shape(tensor, rank=None, scope=None):
"""Fully resolves the shape of a Tensor.
Uses the shape components already known during graph creation as much as
possible, and resolves the remaining ones at runtime.
Args:
tensor: Input tensor whose shape we query.
rank: The rank of the tensor, provided that we know it.
scope: Optional name scope.
Returns:
shape: The full shape of the tensor.
"""
with tf.name_scope(scope, 'resolve_shape', [tensor]):
if rank is not None:
shape = tensor.get_shape().with_rank(rank).as_list()
else:
shape = tensor.get_shape().as_list()
if None in shape:
shape_dynamic = tf.shape(tensor)
for i in range(len(shape)):
if shape[i] is None:
shape[i] = shape_dynamic[i]
return shape
def resize_to_range(image,
label=None,
min_size=None,
max_size=None,
factor=None,
align_corners=True,
label_layout_is_chw=False,
scope=None,
method=tf.image.ResizeMethod.BILINEAR):
"""Resizes image or label so their sides are within the provided range.
The output size can be described by two cases:
1. If the image can be rescaled so its minimum size is equal to min_size
without the other side exceeding max_size, then do so.
2. Otherwise, resize so the largest side is equal to max_size.
An integer in `range(factor)` is added to the computed sides so that the
final dimensions are multiples of `factor` plus one.
Args:
image: A 3D tensor of shape [height, width, channels].
label: (optional) A 3D tensor of shape [height, width, channels] (default)
or [channels, height, width] when label_layout_is_chw = True.
min_size: (scalar) desired size of the smaller image side.
max_size: (scalar) maximum allowed size of the larger image side. Note
that the output dimension is no larger than max_size and may be slightly
smaller than min_size when factor is not None.
factor: Make output size multiple of factor plus one.
align_corners: If True, exactly align all 4 corners of input and output.
label_layout_is_chw: If true, the label has shape [channel, height, width].
We support this case because for some instance segmentation dataset, the
instance segmentation is saved as [num_instances, height, width].
scope: Optional name scope.
method: Image resize method. Defaults to tf.image.ResizeMethod.BILINEAR.
Returns:
A 3-D tensor of shape [new_height, new_width, channels], where the image
has been resized (with the specified method) so that
min(new_height, new_width) == ceil(min_size) or
max(new_height, new_width) == ceil(max_size).
Raises:
ValueError: If the image is not a 3D tensor.
"""
with tf.name_scope(scope, 'resize_to_range', [image]):
new_tensor_list = []
min_size = tf.to_float(min_size)
if max_size is not None:
max_size = tf.to_float(max_size)
# Modify the max_size to be a multiple of factor plus 1 and make sure the
# max dimension after resizing is no larger than max_size.
if factor is not None:
max_size = (max_size + (factor - (max_size - 1) % factor) % factor
- factor)
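# Worked example: with max_size = 98 and factor = 8 this gives
# 98 + (8 - 97 % 8) % 8 - 8 = 97, i.e. a multiple of the factor plus one.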
[orig_height, orig_width, _] = resolve_shape(image, rank=3)
orig_height = tf.to_float(orig_height)
orig_width = tf.to_float(orig_width)
orig_min_size = tf.minimum(orig_height, orig_width)
# Calculate the larger of the possible sizes
large_scale_factor = min_size / orig_min_size
large_height = tf.to_int32(tf.ceil(orig_height * large_scale_factor))
large_width = tf.to_int32(tf.ceil(orig_width * large_scale_factor))
large_size = tf.stack([large_height, large_width])
new_size = large_size
if max_size is not None:
# Calculate the smaller of the possible sizes, use that if the larger
# is too big.
orig_max_size = tf.maximum(orig_height, orig_width)
small_scale_factor = max_size / orig_max_size
small_height = tf.to_int32(tf.ceil(orig_height * small_scale_factor))
small_width = tf.to_int32(tf.ceil(orig_width * small_scale_factor))
small_size = tf.stack([small_height, small_width])
new_size = tf.cond(
tf.to_float(tf.reduce_max(large_size)) > max_size,
lambda: small_size,
lambda: large_size)
# Ensure that both output sides are multiples of factor plus one.
if factor is not None:
new_size += (factor - (new_size - 1) % factor) % factor
new_tensor_list.append(tf.image.resize_images(
image, new_size, method=method, align_corners=align_corners))
if label is not None:
if label_layout_is_chw:
# Input label has shape [channel, height, width].
resized_label = tf.expand_dims(label, 3)
resized_label = tf.image.resize_nearest_neighbor(
resized_label, new_size, align_corners=align_corners)
resized_label = tf.squeeze(resized_label, 3)
else:
# Input label has shape [height, width, channel].
resized_label = tf.image.resize_images(
label, new_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
align_corners=align_corners)
new_tensor_list.append(resized_label)
else:
new_tensor_list.append(None)
return new_tensor_list
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for preprocess_utils."""
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import errors
from deeplab.core import preprocess_utils
class PreprocessUtilsTest(tf.test.TestCase):
def testNoFlipWhenProbIsZero(self):
numpy_image = np.dstack([[[5., 6.],
[9., 0.]],
[[4., 3.],
[3., 5.]]])
image = tf.convert_to_tensor(numpy_image)
with self.test_session():
actual, is_flipped = preprocess_utils.flip_dim([image], prob=0, dim=0)
self.assertAllEqual(numpy_image, actual.eval())
self.assertAllEqual(False, is_flipped.eval())
actual, is_flipped = preprocess_utils.flip_dim([image], prob=0, dim=1)
self.assertAllEqual(numpy_image, actual.eval())
self.assertAllEqual(False, is_flipped.eval())
actual, is_flipped = preprocess_utils.flip_dim([image], prob=0, dim=2)
self.assertAllEqual(numpy_image, actual.eval())
self.assertAllEqual(False, is_flipped.eval())
def testFlipWhenProbIsOne(self):
numpy_image = np.dstack([[[5., 6.],
[9., 0.]],
[[4., 3.],
[3., 5.]]])
dim0_flipped = np.dstack([[[9., 0.],
[5., 6.]],
[[3., 5.],
[4., 3.]]])
dim1_flipped = np.dstack([[[6., 5.],
[0., 9.]],
[[3., 4.],
[5., 3.]]])
dim2_flipped = np.dstack([[[4., 3.],
[3., 5.]],
[[5., 6.],
[9., 0.]]])
image = tf.convert_to_tensor(numpy_image)
with self.test_session():
actual, is_flipped = preprocess_utils.flip_dim([image], prob=1, dim=0)
self.assertAllEqual(dim0_flipped, actual.eval())
self.assertAllEqual(True, is_flipped.eval())
actual, is_flipped = preprocess_utils.flip_dim([image], prob=1, dim=1)
self.assertAllEqual(dim1_flipped, actual.eval())
self.assertAllEqual(True, is_flipped.eval())
actual, is_flipped = preprocess_utils.flip_dim([image], prob=1, dim=2)
self.assertAllEqual(dim2_flipped, actual.eval())
self.assertAllEqual(True, is_flipped.eval())
def testFlipMultipleImagesConsistentlyWhenProbIsOne(self):
numpy_image = np.dstack([[[5., 6.],
[9., 0.]],
[[4., 3.],
[3., 5.]]])
numpy_label = np.dstack([[[0., 1.],
[2., 3.]]])
image_dim1_flipped = np.dstack([[[6., 5.],
[0., 9.]],
[[3., 4.],
[5., 3.]]])
label_dim1_flipped = np.dstack([[[1., 0.],
[3., 2.]]])
image = tf.convert_to_tensor(numpy_image)
label = tf.convert_to_tensor(numpy_label)
with self.test_session() as sess:
image, label, is_flipped = preprocess_utils.flip_dim(
[image, label], prob=1, dim=1)
actual_image, actual_label = sess.run([image, label])
self.assertAllEqual(image_dim1_flipped, actual_image)
self.assertAllEqual(label_dim1_flipped, actual_label)
self.assertEqual(True, is_flipped.eval())
def testReturnRandomFlipsOnMultipleEvals(self):
numpy_image = np.dstack([[[5., 6.],
[9., 0.]],
[[4., 3.],
[3., 5.]]])
dim1_flipped = np.dstack([[[6., 5.],
[0., 9.]],
[[3., 4.],
[5., 3.]]])
image = tf.convert_to_tensor(numpy_image)
tf.set_random_seed(53)
with self.test_session() as sess:
actual, is_flipped = preprocess_utils.flip_dim(
[image], prob=0.5, dim=1)
actual_image, actual_is_flipped = sess.run([actual, is_flipped])
self.assertAllEqual(numpy_image, actual_image)
self.assertEqual(False, actual_is_flipped)
actual_image, actual_is_flipped = sess.run([actual, is_flipped])
self.assertAllEqual(dim1_flipped, actual_image)
self.assertEqual(True, actual_is_flipped)
def testReturnCorrectCropOfSingleImage(self):
np.random.seed(0)
height, width = 10, 20
image = np.random.randint(0, 256, size=(height, width, 3))
crop_height, crop_width = 2, 4
image_placeholder = tf.placeholder(tf.int32, shape=(None, None, 3))
[cropped] = preprocess_utils.random_crop([image_placeholder],
crop_height,
crop_width)
with self.test_session():
cropped_image = cropped.eval(feed_dict={image_placeholder: image})
# Ensure we can find the cropped image in the original:
is_found = False
for x in range(0, width - crop_width + 1):
for y in range(0, height - crop_height + 1):
if np.isclose(image[y:y+crop_height, x:x+crop_width, :],
cropped_image).all():
is_found = True
break
self.assertTrue(is_found)
def testRandomCropMaintainsNumberOfChannels(self):
np.random.seed(0)
crop_height, crop_width = 10, 20
image = np.random.randint(0, 256, size=(100, 200, 3))
tf.set_random_seed(37)
image_placeholder = tf.placeholder(tf.int32, shape=(None, None, 3))
[cropped] = preprocess_utils.random_crop(
[image_placeholder], crop_height, crop_width)
with self.test_session():
cropped_image = cropped.eval(feed_dict={image_placeholder: image})
self.assertTupleEqual(cropped_image.shape, (crop_height, crop_width, 3))
def testReturnDifferentCropAreasOnTwoEvals(self):
tf.set_random_seed(0)
crop_height, crop_width = 2, 3
image = np.random.randint(0, 256, size=(100, 200, 3))
image_placeholder = tf.placeholder(tf.int32, shape=(None, None, 3))
[cropped] = preprocess_utils.random_crop(
[image_placeholder], crop_height, crop_width)
with self.test_session():
crop0 = cropped.eval(feed_dict={image_placeholder: image})
crop1 = cropped.eval(feed_dict={image_placeholder: image})
self.assertFalse(np.isclose(crop0, crop1).all())
def testReturnConsistenCropsOfImagesInTheList(self):
tf.set_random_seed(0)
height, width = 10, 20
crop_height, crop_width = 2, 3
labels = np.linspace(0, height * width-1, height * width)
labels = labels.reshape((height, width, 1))
image = np.tile(labels, (1, 1, 3))
image_placeholder = tf.placeholder(tf.int32, shape=(None, None, 3))
label_placeholder = tf.placeholder(tf.int32, shape=(None, None, 1))
[cropped_image, cropped_label] = preprocess_utils.random_crop(
[image_placeholder, label_placeholder], crop_height, crop_width)
with self.test_session() as sess:
cropped_image, cropped_labels = sess.run([cropped_image, cropped_label],
feed_dict={
image_placeholder: image,
label_placeholder: labels})
for i in range(3):
self.assertAllEqual(cropped_image[:, :, i], cropped_labels.squeeze())
def testDieOnRandomCropWhenImagesWithDifferentWidth(self):
crop_height, crop_width = 2, 3
image1 = tf.placeholder(tf.float32, name='image1', shape=(None, None, 3))
image2 = tf.placeholder(tf.float32, name='image2', shape=(None, None, 1))
cropped = preprocess_utils.random_crop(
[image1, image2], crop_height, crop_width)
with self.test_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(cropped, feed_dict={image1: np.random.rand(4, 5, 3),
image2: np.random.rand(4, 6, 1)})
def testDieOnRandomCropWhenImagesWithDifferentHeight(self):
crop_height, crop_width = 2, 3
image1 = tf.placeholder(tf.float32, name='image1', shape=(None, None, 3))
image2 = tf.placeholder(tf.float32, name='image2', shape=(None, None, 1))
cropped = preprocess_utils.random_crop(
[image1, image2], crop_height, crop_width)
with self.test_session() as sess:
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
'Wrong height for tensor'):
sess.run(cropped, feed_dict={image1: np.random.rand(4, 5, 3),
image2: np.random.rand(3, 5, 1)})
def testDieOnRandomCropWhenCropSizeIsGreaterThanImage(self):
crop_height, crop_width = 5, 9
image1 = tf.placeholder(tf.float32, name='image1', shape=(None, None, 3))
image2 = tf.placeholder(tf.float32, name='image2', shape=(None, None, 1))
cropped = preprocess_utils.random_crop(
[image1, image2], crop_height, crop_width)
with self.test_session() as sess:
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
'Crop size greater than the image size.'):
sess.run(cropped, feed_dict={image1: np.random.rand(4, 5, 3),
image2: np.random.rand(4, 5, 1)})
def testReturnPaddedImageWithNonZeroPadValue(self):
for dtype in [np.int32, np.int64, np.float32, np.float64]:
image = np.dstack([[[5, 6],
[9, 0]],
[[4, 3],
[3, 5]]]).astype(dtype)
expected_image = np.dstack([[[255, 255, 255, 255, 255],
[255, 255, 255, 255, 255],
[255, 5, 6, 255, 255],
[255, 9, 0, 255, 255],
[255, 255, 255, 255, 255]],
[[255, 255, 255, 255, 255],
[255, 255, 255, 255, 255],
[255, 4, 3, 255, 255],
[255, 3, 5, 255, 255],
[255, 255, 255, 255, 255]]]).astype(dtype)
with self.test_session():
image_placeholder = tf.placeholder(tf.float32)
padded_image = preprocess_utils.pad_to_bounding_box(
image_placeholder, 2, 1, 5, 5, 255)
self.assertAllClose(padded_image.eval(
feed_dict={image_placeholder: image}), expected_image)
def testReturnOriginalImageWhenTargetSizeIsEqualToImageSize(self):
image = np.dstack([[[5, 6],
[9, 0]],
[[4, 3],
[3, 5]]])
with self.test_session():
image_placeholder = tf.placeholder(tf.float32)
padded_image = preprocess_utils.pad_to_bounding_box(
image_placeholder, 0, 0, 2, 2, 255)
self.assertAllClose(padded_image.eval(
feed_dict={image_placeholder: image}), image)
def testDieOnTargetSizeGreaterThanImageSize(self):
image = np.dstack([[[5, 6],
[9, 0]],
[[4, 3],
[3, 5]]])
with self.test_session():
image_placeholder = tf.placeholder(tf.float32)
padded_image = preprocess_utils.pad_to_bounding_box(
image_placeholder, 0, 0, 2, 1, 255)
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
'target_width must be >= width'):
padded_image.eval(feed_dict={image_placeholder: image})
padded_image = preprocess_utils.pad_to_bounding_box(
image_placeholder, 0, 0, 1, 2, 255)
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
'target_height must be >= height'):
padded_image.eval(feed_dict={image_placeholder: image})
def testDieIfTargetSizeNotPossibleWithGivenOffset(self):
image = np.dstack([[[5, 6],
[9, 0]],
[[4, 3],
[3, 5]]])
with self.test_session():
image_placeholder = tf.placeholder(tf.float32)
padded_image = preprocess_utils.pad_to_bounding_box(
image_placeholder, 3, 0, 4, 4, 255)
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
'target size not possible with the given target offsets'):
padded_image.eval(feed_dict={image_placeholder: image})
def testDieIfImageTensorRankIsNotThree(self):
image = np.vstack([[5, 6],
[9, 0]])
with self.test_session():
image_placeholder = tf.placeholder(tf.float32)
padded_image = preprocess_utils.pad_to_bounding_box(
image_placeholder, 0, 0, 2, 2, 255)
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
'Wrong image tensor rank'):
padded_image.eval(feed_dict={image_placeholder: image})
def testResizeTensorsToRange(self):
test_shapes = [[60, 40],
[15, 30],
[15, 50]]
min_size = 50
max_size = 100
factor = None
expected_shape_list = [(75, 50, 3),
(50, 100, 3),
(30, 100, 3)]
for i, test_shape in enumerate(test_shapes):
image = tf.random_normal([test_shape[0], test_shape[1], 3])
new_tensor_list = preprocess_utils.resize_to_range(
image=image,
label=None,
min_size=min_size,
max_size=max_size,
factor=factor,
align_corners=True)
with self.test_session() as session:
resized_image = session.run(new_tensor_list[0])
self.assertEqual(resized_image.shape, expected_shape_list[i])
def testResizeTensorsToRangeWithFactor(self):
test_shapes = [[60, 40],
[15, 30],
[15, 50]]
min_size = 50
max_size = 98
factor = 8
expected_image_shape_list = [(81, 57, 3),
(49, 97, 3),
(33, 97, 3)]
expected_label_shape_list = [(81, 57, 1),
(49, 97, 1),
(33, 97, 1)]
for i, test_shape in enumerate(test_shapes):
image = tf.random_normal([test_shape[0], test_shape[1], 3])
label = tf.random_normal([test_shape[0], test_shape[1], 1])
new_tensor_list = preprocess_utils.resize_to_range(
image=image,
label=label,
min_size=min_size,
max_size=max_size,
factor=factor,
align_corners=True)
with self.test_session() as session:
new_tensor_list = session.run(new_tensor_list)
self.assertEqual(new_tensor_list[0].shape, expected_image_shape_list[i])
self.assertEqual(new_tensor_list[1].shape, expected_label_shape_list[i])
def testResizeTensorsToRangeWithFactorAndLabelShapeCHW(self):
test_shapes = [[60, 40],
[15, 30],
[15, 50]]
min_size = 50
max_size = 98
factor = 8
expected_image_shape_list = [(81, 57, 3),
(49, 97, 3),
(33, 97, 3)]
expected_label_shape_list = [(5, 81, 57),
(5, 49, 97),
(5, 33, 97)]
for i, test_shape in enumerate(test_shapes):
image = tf.random_normal([test_shape[0], test_shape[1], 3])
label = tf.random_normal([5, test_shape[0], test_shape[1]])
new_tensor_list = preprocess_utils.resize_to_range(
image=image,
label=label,
min_size=min_size,
max_size=max_size,
factor=factor,
align_corners=True,
label_layout_is_chw=True)
with self.test_session() as session:
new_tensor_list = session.run(new_tensor_list)
self.assertEqual(new_tensor_list[0].shape, expected_image_shape_list[i])
self.assertEqual(new_tensor_list[1].shape, expected_label_shape_list[i])
def testResizeTensorsToRangeWithSimilarMinMaxSizes(self):
test_shapes = [[60, 40],
[15, 30],
[15, 50]]
# Values set so that one of the side = 97.
min_size = 96
max_size = 98
factor = 8
expected_image_shape_list = [(97, 65, 3),
(49, 97, 3),
(33, 97, 3)]
expected_label_shape_list = [(97, 65, 1),
(49, 97, 1),
(33, 97, 1)]
for i, test_shape in enumerate(test_shapes):
image = tf.random_normal([test_shape[0], test_shape[1], 3])
label = tf.random_normal([test_shape[0], test_shape[1], 1])
new_tensor_list = preprocess_utils.resize_to_range(
image=image,
label=label,
min_size=min_size,
max_size=max_size,
factor=factor,
align_corners=True)
with self.test_session() as session:
new_tensor_list = session.run(new_tensor_list)
self.assertEqual(new_tensor_list[0].shape, expected_image_shape_list[i])
self.assertEqual(new_tensor_list[1].shape, expected_label_shape_list[i])
if __name__ == '__main__':
tf.test.main()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for xception.py."""
import numpy as np
import tensorflow as tf
from deeplab.core import xception
from tensorflow.contrib.slim.nets import resnet_utils
slim = tf.contrib.slim
def create_test_input(batch, height, width, channels):
"""Create test input tensor."""
if None in [batch, height, width, channels]:
return tf.placeholder(tf.float32, (batch, height, width, channels))
else:
return tf.to_float(
np.tile(
np.reshape(
np.reshape(np.arange(height), [height, 1]) +
np.reshape(np.arange(width), [1, width]),
[1, height, width, 1]),
[batch, 1, 1, channels]))
class UtilityFunctionTest(tf.test.TestCase):
def testSeparableConv2DSameWithInputEvenSize(self):
n, n2 = 4, 2
# Input image.
x = create_test_input(1, n, n, 1)
# Convolution kernel.
dw = create_test_input(1, 3, 3, 1)
dw = tf.reshape(dw, [3, 3, 1, 1])
tf.get_variable('Conv/depthwise_weights', initializer=dw)
tf.get_variable('Conv/pointwise_weights',
initializer=tf.ones([1, 1, 1, 1]))
tf.get_variable('Conv/biases', initializer=tf.zeros([1]))
tf.get_variable_scope().reuse_variables()
y1 = slim.separable_conv2d(x, 1, [3, 3], depth_multiplier=1,
stride=1, scope='Conv')
y1_expected = tf.to_float([[14, 28, 43, 26],
[28, 48, 66, 37],
[43, 66, 84, 46],
[26, 37, 46, 22]])
y1_expected = tf.reshape(y1_expected, [1, n, n, 1])
y2 = resnet_utils.subsample(y1, 2)
y2_expected = tf.to_float([[14, 43],
[43, 84]])
y2_expected = tf.reshape(y2_expected, [1, n2, n2, 1])
y3 = xception.separable_conv2d_same(x, 1, 3, depth_multiplier=1,
regularize_depthwise=True,
stride=2, scope='Conv')
y3_expected = y2_expected
y4 = slim.separable_conv2d(x, 1, [3, 3], depth_multiplier=1,
stride=2, scope='Conv')
y4_expected = tf.to_float([[48, 37],
[37, 22]])
y4_expected = tf.reshape(y4_expected, [1, n2, n2, 1])
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
self.assertAllClose(y1.eval(), y1_expected.eval())
self.assertAllClose(y2.eval(), y2_expected.eval())
self.assertAllClose(y3.eval(), y3_expected.eval())
self.assertAllClose(y4.eval(), y4_expected.eval())
def testSeparableConv2DSameWithInputOddSize(self):
n, n2 = 5, 3
# Input image.
x = create_test_input(1, n, n, 1)
# Convolution kernel.
dw = create_test_input(1, 3, 3, 1)
dw = tf.reshape(dw, [3, 3, 1, 1])
tf.get_variable('Conv/depthwise_weights', initializer=dw)
tf.get_variable('Conv/pointwise_weights',
initializer=tf.ones([1, 1, 1, 1]))
tf.get_variable('Conv/biases', initializer=tf.zeros([1]))
tf.get_variable_scope().reuse_variables()
y1 = slim.separable_conv2d(x, 1, [3, 3], depth_multiplier=1,
stride=1, scope='Conv')
y1_expected = tf.to_float([[14, 28, 43, 58, 34],
[28, 48, 66, 84, 46],
[43, 66, 84, 102, 55],
[58, 84, 102, 120, 64],
[34, 46, 55, 64, 30]])
y1_expected = tf.reshape(y1_expected, [1, n, n, 1])
y2 = resnet_utils.subsample(y1, 2)
y2_expected = tf.to_float([[14, 43, 34],
[43, 84, 55],
[34, 55, 30]])
y2_expected = tf.reshape(y2_expected, [1, n2, n2, 1])
y3 = xception.separable_conv2d_same(x, 1, 3, depth_multiplier=1,
regularize_depthwise=True,
stride=2, scope='Conv')
y3_expected = y2_expected
y4 = slim.separable_conv2d(x, 1, [3, 3], depth_multiplier=1,
stride=2, scope='Conv')
y4_expected = y2_expected
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
self.assertAllClose(y1.eval(), y1_expected.eval())
self.assertAllClose(y2.eval(), y2_expected.eval())
self.assertAllClose(y3.eval(), y3_expected.eval())
self.assertAllClose(y4.eval(), y4_expected.eval())
class XceptionNetworkTest(tf.test.TestCase):
"""Tests with small Xception network."""
def _xception_small(self,
inputs,
num_classes=None,
is_training=True,
global_pool=True,
output_stride=None,
regularize_depthwise=True,
reuse=None,
scope='xception_small'):
"""A shallow and thin Xception for faster tests."""
block = xception.xception_block
blocks = [
block('entry_flow/block1',
depth_list=[1, 1, 1],
skip_connection_type='conv',
activation_fn_in_separable_conv=False,
regularize_depthwise=regularize_depthwise,
num_units=1,
stride=2),
block('entry_flow/block2',
depth_list=[2, 2, 2],
skip_connection_type='conv',
activation_fn_in_separable_conv=False,
regularize_depthwise=regularize_depthwise,
num_units=1,
stride=2),
block('entry_flow/block3',
depth_list=[4, 4, 4],
skip_connection_type='conv',
activation_fn_in_separable_conv=False,
regularize_depthwise=regularize_depthwise,
num_units=1,
stride=1),
block('entry_flow/block4',
depth_list=[4, 4, 4],
skip_connection_type='conv',
activation_fn_in_separable_conv=False,
regularize_depthwise=regularize_depthwise,
num_units=1,
stride=2),
block('middle_flow/block1',
depth_list=[4, 4, 4],
skip_connection_type='sum',
activation_fn_in_separable_conv=False,
regularize_depthwise=regularize_depthwise,
num_units=2,
stride=1),
block('exit_flow/block1',
depth_list=[8, 8, 8],
skip_connection_type='conv',
activation_fn_in_separable_conv=False,
regularize_depthwise=regularize_depthwise,
num_units=1,
stride=2),
block('exit_flow/block2',
depth_list=[16, 16, 16],
skip_connection_type='none',
activation_fn_in_separable_conv=True,
regularize_depthwise=regularize_depthwise,
num_units=1,
stride=1),
]
return xception.xception(inputs,
blocks=blocks,
num_classes=num_classes,
is_training=is_training,
global_pool=global_pool,
output_stride=output_stride,
reuse=reuse,
scope=scope)
def testClassificationEndPoints(self):
global_pool = True
num_classes = 10
inputs = create_test_input(2, 224, 224, 3)
with slim.arg_scope(xception.xception_arg_scope()):
logits, end_points = self._xception_small(
inputs,
num_classes=num_classes,
global_pool=global_pool,
scope='xception')
self.assertTrue(
logits.op.name.startswith('xception/logits'))
self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
self.assertTrue('predictions' in end_points)
self.assertListEqual(end_points['predictions'].get_shape().as_list(),
[2, 1, 1, num_classes])
self.assertTrue('global_pool' in end_points)
self.assertListEqual(end_points['global_pool'].get_shape().as_list(),
[2, 1, 1, 16])
def testEndpointNames(self):
global_pool = True
num_classes = 10
inputs = create_test_input(2, 224, 224, 3)
with slim.arg_scope(xception.xception_arg_scope()):
_, end_points = self._xception_small(
inputs,
num_classes=num_classes,
global_pool=global_pool,
scope='xception')
expected = [
'xception/entry_flow/conv1_1',
'xception/entry_flow/conv1_2',
'xception/entry_flow/block1/unit_1/xception_module/separable_conv1',
'xception/entry_flow/block1/unit_1/xception_module/separable_conv2',
'xception/entry_flow/block1/unit_1/xception_module/separable_conv3',
'xception/entry_flow/block1/unit_1/xception_module/shortcut',
'xception/entry_flow/block1/unit_1/xception_module',
'xception/entry_flow/block1',
'xception/entry_flow/block2/unit_1/xception_module/separable_conv1',
'xception/entry_flow/block2/unit_1/xception_module/separable_conv2',
'xception/entry_flow/block2/unit_1/xception_module/separable_conv3',
'xception/entry_flow/block2/unit_1/xception_module/shortcut',
'xception/entry_flow/block2/unit_1/xception_module',
'xception/entry_flow/block2',
'xception/entry_flow/block3/unit_1/xception_module/separable_conv1',
'xception/entry_flow/block3/unit_1/xception_module/separable_conv2',
'xception/entry_flow/block3/unit_1/xception_module/separable_conv3',
'xception/entry_flow/block3/unit_1/xception_module/shortcut',
'xception/entry_flow/block3/unit_1/xception_module',
'xception/entry_flow/block3',
'xception/entry_flow/block4/unit_1/xception_module/separable_conv1',
'xception/entry_flow/block4/unit_1/xception_module/separable_conv2',
'xception/entry_flow/block4/unit_1/xception_module/separable_conv3',
'xception/entry_flow/block4/unit_1/xception_module/shortcut',
'xception/entry_flow/block4/unit_1/xception_module',
'xception/entry_flow/block4',
'xception/middle_flow/block1/unit_1/xception_module/separable_conv1',
'xception/middle_flow/block1/unit_1/xception_module/separable_conv2',
'xception/middle_flow/block1/unit_1/xception_module/separable_conv3',
'xception/middle_flow/block1/unit_1/xception_module',
'xception/middle_flow/block1/unit_2/xception_module/separable_conv1',
'xception/middle_flow/block1/unit_2/xception_module/separable_conv2',
'xception/middle_flow/block1/unit_2/xception_module/separable_conv3',
'xception/middle_flow/block1/unit_2/xception_module',
'xception/middle_flow/block1',
'xception/exit_flow/block1/unit_1/xception_module/separable_conv1',
'xception/exit_flow/block1/unit_1/xception_module/separable_conv2',
'xception/exit_flow/block1/unit_1/xception_module/separable_conv3',
'xception/exit_flow/block1/unit_1/xception_module/shortcut',
'xception/exit_flow/block1/unit_1/xception_module',
'xception/exit_flow/block1',
'xception/exit_flow/block2/unit_1/xception_module/separable_conv1',
'xception/exit_flow/block2/unit_1/xception_module/separable_conv2',
'xception/exit_flow/block2/unit_1/xception_module/separable_conv3',
'xception/exit_flow/block2/unit_1/xception_module',
'xception/exit_flow/block2',
'global_pool',
'xception/logits',
'predictions',
]
self.assertItemsEqual(end_points.keys(), expected)
def testClassificationShapes(self):
global_pool = True
num_classes = 10
inputs = create_test_input(2, 224, 224, 3)
with slim.arg_scope(xception.xception_arg_scope()):
_, end_points = self._xception_small(
inputs,
num_classes,
global_pool=global_pool,
scope='xception')
endpoint_to_shape = {
'xception/entry_flow/conv1_1': [2, 112, 112, 32],
'xception/entry_flow/block1': [2, 56, 56, 1],
'xception/entry_flow/block2': [2, 28, 28, 2],
'xception/entry_flow/block4': [2, 14, 14, 4],
'xception/middle_flow/block1': [2, 14, 14, 4],
'xception/exit_flow/block1': [2, 7, 7, 8],
'xception/exit_flow/block2': [2, 7, 7, 16]}
for endpoint, shape in endpoint_to_shape.iteritems():
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testFullyConvolutionalEndpointShapes(self):
global_pool = False
num_classes = 10
inputs = create_test_input(2, 321, 321, 3)
with slim.arg_scope(xception.xception_arg_scope()):
_, end_points = self._xception_small(
inputs,
num_classes,
global_pool=global_pool,
scope='xception')
endpoint_to_shape = {
'xception/entry_flow/conv1_1': [2, 161, 161, 32],
'xception/entry_flow/block1': [2, 81, 81, 1],
'xception/entry_flow/block2': [2, 41, 41, 2],
'xception/entry_flow/block4': [2, 21, 21, 4],
'xception/middle_flow/block1': [2, 21, 21, 4],
'xception/exit_flow/block1': [2, 11, 11, 8],
'xception/exit_flow/block2': [2, 11, 11, 16]}
for endpoint, shape in endpoint_to_shape.iteritems():
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testAtrousFullyConvolutionalEndpointShapes(self):
global_pool = False
num_classes = 10
output_stride = 8
inputs = create_test_input(2, 321, 321, 3)
with slim.arg_scope(xception.xception_arg_scope()):
_, end_points = self._xception_small(
inputs,
num_classes,
global_pool=global_pool,
output_stride=output_stride,
scope='xception')
endpoint_to_shape = {
'xception/entry_flow/block1': [2, 81, 81, 1],
'xception/entry_flow/block2': [2, 41, 41, 2],
'xception/entry_flow/block4': [2, 41, 41, 4],
'xception/middle_flow/block1': [2, 41, 41, 4],
'xception/exit_flow/block1': [2, 41, 41, 8],
'xception/exit_flow/block2': [2, 41, 41, 16]}
for endpoint, shape in endpoint_to_shape.iteritems():
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testAtrousFullyConvolutionalValues(self):
"""Verify dense feature extraction with atrous convolution."""
nominal_stride = 32
for output_stride in [4, 8, 16, 32, None]:
with slim.arg_scope(xception.xception_arg_scope()):
with tf.Graph().as_default():
with self.test_session() as sess:
tf.set_random_seed(0)
inputs = create_test_input(2, 96, 97, 3)
# Dense feature extraction followed by subsampling.
output, _ = self._xception_small(
inputs,
None,
is_training=False,
global_pool=False,
output_stride=output_stride)
if output_stride is None:
factor = 1
else:
factor = nominal_stride // output_stride
output = resnet_utils.subsample(output, factor)
# Make the two networks use the same weights.
tf.get_variable_scope().reuse_variables()
# Feature extraction at the nominal network rate.
expected, _ = self._xception_small(
inputs,
None,
is_training=False,
global_pool=False)
sess.run(tf.global_variables_initializer())
self.assertAllClose(output.eval(), expected.eval(),
atol=1e-5, rtol=1e-5)
def testUnknownBatchSize(self):
batch = 2
height, width = 65, 65
global_pool = True
num_classes = 10
inputs = create_test_input(None, height, width, 3)
with slim.arg_scope(xception.xception_arg_scope()):
logits, _ = self._xception_small(
inputs,
num_classes,
global_pool=global_pool,
scope='xception')
self.assertTrue(logits.op.name.startswith('xception/logits'))
self.assertListEqual(logits.get_shape().as_list(),
[None, 1, 1, num_classes])
images = create_test_input(batch, height, width, 3)
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
output = sess.run(logits, {inputs: images.eval()})
self.assertEquals(output.shape, (batch, 1, 1, num_classes))
def testFullyConvolutionalUnknownHeightWidth(self):
batch = 2
height, width = 65, 65
global_pool = False
inputs = create_test_input(batch, None, None, 3)
with slim.arg_scope(xception.xception_arg_scope()):
output, _ = self._xception_small(
inputs,
None,
global_pool=global_pool)
self.assertListEqual(output.get_shape().as_list(),
[batch, None, None, 16])
images = create_test_input(batch, height, width, 3)
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
output = sess.run(output, {inputs: images.eval()})
self.assertEquals(output.shape, (batch, 3, 3, 16))
def testAtrousFullyConvolutionalUnknownHeightWidth(self):
batch = 2
height, width = 65, 65
global_pool = False
output_stride = 8
inputs = create_test_input(batch, None, None, 3)
with slim.arg_scope(xception.xception_arg_scope()):
output, _ = self._xception_small(
inputs,
None,
global_pool=global_pool,
output_stride=output_stride)
self.assertListEqual(output.get_shape().as_list(),
[batch, None, None, 16])
images = create_test_input(batch, height, width, 3)
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
output = sess.run(output, {inputs: images.eval()})
self.assertEquals(output.shape, (batch, 9, 9, 16))
def testEndpointsReuse(self):
inputs = create_test_input(2, 32, 32, 3)
with slim.arg_scope(xception.xception_arg_scope()):
_, end_points0 = xception.xception_65(
inputs,
num_classes=10,
reuse=False)
with slim.arg_scope(xception.xception_arg_scope()):
_, end_points1 = xception.xception_65(
inputs,
num_classes=10,
reuse=True)
self.assertItemsEqual(end_points0.keys(), end_points1.keys())
if __name__ == '__main__':
tf.test.main()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts Cityscapes data to TFRecord file format with Example protos.
The Cityscapes dataset is expected to have the following directory structure:
+ cityscapes
- build_cityscapes_data.py (current working directory).
- build_data.py
+ cityscapesscripts
+ annotation
+ evaluation
+ helpers
+ preparation
+ viewer
+ gtFine
+ train
+ val
+ test
+ leftImg8bit
+ train
+ val
+ test
+ tfrecord
This script converts the data into sharded TFRecord files saved in the tfrecord
folder.
Note that before running this script, the users should (1) register at the
Cityscapes dataset website https://www.cityscapes-dataset.com to download the
dataset, and (2) run the script `preparation/createTrainIdLabelImgs.py`
provided by Cityscapes to generate the training ground truth.
Also note that the TensorFlow model will be trained with `TrainId` instead of
the `EvalId` used on the evaluation server. Thus, the users need to convert the
predicted labels to `EvalId` for evaluation on the server. See vis.py for more
details.
The Example proto contains the following fields:
image/encoded: encoded image content.
image/filename: image filename.
image/format: image file format.
image/height: image height.
image/width: image width.
image/channels: image channels.
image/segmentation/class/encoded: encoded semantic segmentation content.
image/segmentation/class/format: semantic segmentation file format.
"""
import glob
import math
import os.path
import re
import sys
import build_data
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('cityscapes_root',
'./cityscapes',
'Cityscapes dataset root folder.')
tf.app.flags.DEFINE_string(
'output_dir',
'./tfrecord',
'Path to save converted SSTable of TensorFlow examples.')
_NUM_SHARDS = 10
# A map from data type to folder name that saves the data.
_FOLDERS_MAP = {
'image': 'leftImg8bit',
'label': 'gtFine',
}
# A map from data type to filename postfix.
_POSTFIX_MAP = {
'image': '_leftImg8bit',
'label': '_gtFine_labelTrainIds',
}
# A map from data type to data format.
_DATA_FORMAT_MAP = {
'image': 'png',
'label': 'png',
}
# Image file pattern.
_IMAGE_FILENAME_RE = re.compile('(.+)' + _POSTFIX_MAP['image'])
def _get_files(data, dataset_split):
"""Gets files for the specified data type and dataset split.
Args:
data: String, desired data ('image' or 'label').
dataset_split: String, dataset split ('train', 'val', 'test')
Returns:
A list of sorted file names, or None when getting labels for the test set.
"""
if data == 'label' and dataset_split == 'test':
return None
pattern = '*%s.%s' % (_POSTFIX_MAP[data], _DATA_FORMAT_MAP[data])
search_files = os.path.join(
FLAGS.cityscapes_root, _FOLDERS_MAP[data], dataset_split, '*', pattern)
filenames = glob.glob(search_files)
return sorted(filenames)
def _convert_dataset(dataset_split):
"""Converts the specified dataset split to TFRecord format.
Args:
dataset_split: The dataset split (e.g., train, val).
Raises:
RuntimeError: If loaded image and label have different shape, or if the
image file with specified postfix could not be found.
"""
image_files = _get_files('image', dataset_split)
label_files = _get_files('label', dataset_split)
num_images = len(image_files)
num_per_shard = int(math.ceil(num_images / float(_NUM_SHARDS)))
image_reader = build_data.ImageReader('png', channels=3)
label_reader = build_data.ImageReader('png', channels=1)
for shard_id in range(_NUM_SHARDS):
shard_filename = '%s-%05d-of-%05d.tfrecord' % (
dataset_split, shard_id, _NUM_SHARDS)
output_filename = os.path.join(FLAGS.output_dir, shard_filename)
with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
start_idx = shard_id * num_per_shard
end_idx = min((shard_id + 1) * num_per_shard, num_images)
for i in range(start_idx, end_idx):
sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
i + 1, num_images, shard_id))
sys.stdout.flush()
# Read the image.
image_data = tf.gfile.FastGFile(image_files[i], 'rb').read()
height, width = image_reader.read_image_dims(image_data)
# Read the semantic segmentation annotation.
seg_data = tf.gfile.FastGFile(label_files[i], 'rb').read()
seg_height, seg_width = label_reader.read_image_dims(seg_data)
if height != seg_height or width != seg_width:
raise RuntimeError('Shape mismatched between image and label.')
# Convert to tf example.
re_match = _IMAGE_FILENAME_RE.search(image_files[i])
if re_match is None:
raise RuntimeError('Invalid image filename: ' + image_files[i])
filename = os.path.basename(re_match.group(1))
example = build_data.image_seg_to_tfexample(
image_data, filename, height, width, seg_data)
tfrecord_writer.write(example.SerializeToString())
sys.stdout.write('\n')
sys.stdout.flush()
def main(unused_argv):
# Only support converting 'train' and 'val' sets for now.
for dataset_split in ['train', 'val']:
_convert_dataset(dataset_split)
if __name__ == '__main__':
tf.app.run()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains common utility functions and classes for building dataset.
This script contains utility functions and classes to convert datasets to
TFRecord file format with Example protos.
The Example proto contains the following fields:
image/encoded: encoded image content.
image/filename: image filename.
image/format: image file format.
image/height: image height.
image/width: image width.
image/channels: image channels.
image/segmentation/class/encoded: encoded semantic segmentation content.
image/segmentation/class/format: semantic segmentation file format.
"""
import collections
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_enum('image_format', 'png', ['jpg', 'jpeg', 'png'],
'Image format.')
tf.app.flags.DEFINE_enum('label_format', 'png', ['png'],
'Segmentation label format.')
# A map from image format to expected data format.
_IMAGE_FORMAT_MAP = {
'jpg': 'jpeg',
'jpeg': 'jpeg',
'png': 'png',
}
class ImageReader(object):
"""Helper class that provides TensorFlow image coding utilities."""
def __init__(self, image_format='jpeg', channels=3):
"""Class constructor.
Args:
image_format: Image format. Only 'jpeg', 'jpg', or 'png' are supported.
channels: Image channels.
"""
with tf.Graph().as_default():
self._decode_data = tf.placeholder(dtype=tf.string)
self._image_format = image_format
self._session = tf.Session()
if self._image_format in ('jpeg', 'jpg'):
self._decode = tf.image.decode_jpeg(self._decode_data,
channels=channels)
elif self._image_format == 'png':
self._decode = tf.image.decode_png(self._decode_data,
channels=channels)
def read_image_dims(self, image_data):
"""Reads the image dimensions.
Args:
image_data: string of image data.
Returns:
image_height and image_width.
"""
image = self.decode_image(image_data)
return image.shape[:2]
def decode_image(self, image_data):
"""Decodes the image data string.
Args:
image_data: string of image data.
Returns:
Decoded image data.
Raises:
ValueError: Value of image channels not supported.
"""
image = self._session.run(self._decode,
feed_dict={self._decode_data: image_data})
if len(image.shape) != 3 or image.shape[2] not in (1, 3):
raise ValueError('The image channels are not supported.')
return image
def _int64_list_feature(values):
"""Returns a TF-Feature of int64_list.
Args:
values: A scalar or list of values.
Returns:
A TF-Feature.
"""
if not isinstance(values, collections.Iterable):
values = [values]
return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
def _bytes_list_feature(values):
"""Returns a TF-Feature of bytes.
Args:
values: A string.
Returns:
A TF-Feature.
"""
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
def image_seg_to_tfexample(image_data, filename, height, width, seg_data):
"""Converts one image/segmentation pair to tf example.
Args:
image_data: string of image data.
filename: image filename.
height: image height.
width: image width.
seg_data: string of semantic segmentation data.
Returns:
tf example of one image/segmentation pair.
"""
return tf.train.Example(features=tf.train.Features(feature={
'image/encoded': _bytes_list_feature(image_data),
'image/filename': _bytes_list_feature(filename),
'image/format': _bytes_list_feature(
_IMAGE_FORMAT_MAP[FLAGS.image_format]),
'image/height': _int64_list_feature(height),
'image/width': _int64_list_feature(width),
'image/channels': _int64_list_feature(3),
'image/segmentation/class/encoded': (
_bytes_list_feature(seg_data)),
'image/segmentation/class/format': _bytes_list_feature(
FLAGS.label_format),
}))
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts PASCAL VOC 2012 data to TFRecord file format with Example protos.
The PASCAL VOC 2012 dataset is expected to have the following directory structure:
+ pascal_voc_seg
- build_data.py
- build_voc2012_data.py (current working directory).
+ VOCdevkit
+ VOC2012
+ JPEGImages
+ SegmentationClass
+ ImageSets
+ Segmentation
+ tfrecord
Image folder:
./VOCdevkit/VOC2012/JPEGImages
Semantic segmentation annotations:
./VOCdevkit/VOC2012/SegmentationClass
List folder:
./VOCdevkit/VOC2012/ImageSets/Segmentation
This script converts the data into sharded TFRecord files saved in the tfrecord
folder.
The Example proto contains the following fields:
image/encoded: encoded image content.
image/filename: image filename.
image/format: image file format.
image/height: image height.
image/width: image width.
image/channels: image channels.
image/segmentation/class/encoded: encoded semantic segmentation content.
image/segmentation/class/format: semantic segmentation file format.
"""
import glob
import math
import os.path
import sys
import build_data
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('image_folder',
'./VOCdevkit/VOC2012/JPEGImages',
'Folder containing images.')
tf.app.flags.DEFINE_string(
'semantic_segmentation_folder',
'./VOCdevkit/VOC2012/SegmentationClassRaw',
'Folder containing semantic segmentation annotations.')
tf.app.flags.DEFINE_string(
'list_folder',
'./VOCdevkit/VOC2012/ImageSets/Segmentation',
'Folder containing lists for training and validation.')
tf.app.flags.DEFINE_string(
'output_dir',
'./tfrecord',
'Path to save converted SSTable of TensorFlow examples.')
_NUM_SHARDS = 4
def _convert_dataset(dataset_split):
"""Converts the specified dataset split to TFRecord format.
Args:
dataset_split: The dataset split (e.g., train, test).
Raises:
RuntimeError: If loaded image and label have different shape.
"""
dataset = os.path.basename(dataset_split)[:-4]
sys.stdout.write('Processing ' + dataset)
filenames = [x.strip('\n') for x in open(dataset_split, 'r')]
num_images = len(filenames)
num_per_shard = int(math.ceil(num_images / float(_NUM_SHARDS)))
image_reader = build_data.ImageReader('jpeg', channels=3)
label_reader = build_data.ImageReader('png', channels=1)
for shard_id in range(_NUM_SHARDS):
output_filename = os.path.join(
FLAGS.output_dir,
'%s-%05d-of-%05d.tfrecord' % (dataset, shard_id, _NUM_SHARDS))
with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
start_idx = shard_id * num_per_shard
end_idx = min((shard_id + 1) * num_per_shard, num_images)
for i in range(start_idx, end_idx):
sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
i + 1, len(filenames), shard_id))
sys.stdout.flush()
# Read the image.
image_filename = os.path.join(
FLAGS.image_folder, filenames[i] + '.' + FLAGS.image_format)
image_data = tf.gfile.FastGFile(image_filename, 'rb').read()
height, width = image_reader.read_image_dims(image_data)
# Read the semantic segmentation annotation.
seg_filename = os.path.join(
FLAGS.semantic_segmentation_folder,
filenames[i] + '.' + FLAGS.label_format)
seg_data = tf.gfile.FastGFile(seg_filename, 'rb').read()
seg_height, seg_width = label_reader.read_image_dims(seg_data)
if height != seg_height or width != seg_width:
raise RuntimeError('Shape mismatched between image and label.')
# Convert to tf example.
example = build_data.image_seg_to_tfexample(
image_data, filenames[i], height, width, seg_data)
tfrecord_writer.write(example.SerializeToString())
sys.stdout.write('\n')
sys.stdout.flush()
def main(unused_argv):
dataset_splits = glob.glob(os.path.join(FLAGS.list_folder, '*.txt'))
for dataset_split in dataset_splits:
_convert_dataset(dataset_split)
if __name__ == '__main__':
tf.app.run()
#!/bin/bash
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Script to preprocess the Cityscapes dataset. Note that (1) the users should
# register at the Cityscapes dataset website https://www.cityscapes-dataset.com/downloads/
# to download the dataset, and (2) the users should run the script
# `preparation/createTrainIdLabelImgs.py` provided by Cityscapes to generate
# the training ground truth.
#
# Usage:
# bash ./preprocess_cityscapes.sh
#
# The folder structure is assumed to be:
# + data
# - build_cityscapes_data.py
# + cityscapes
# + cityscapesscripts
# + gtFine
# + leftImg8bit
#
# Exit immediately if a command exits with a non-zero status.
set -e
CURRENT_DIR=$(pwd)
WORK_DIR="."
cd "${CURRENT_DIR}"
# Root path for the Cityscapes dataset.
CITYSCAPES_ROOT="${WORK_DIR}/cityscapes"
# Build TFRecords of the dataset.
# First, create output directory for storing TFRecords.
OUTPUT_DIR="${CITYSCAPES_ROOT}/tfrecord"
mkdir -p "${OUTPUT_DIR}"
BUILD_SCRIPT="${WORK_DIR}/build_cityscapes_data.py"
echo "Converting Cityscapes dataset..."
python "${BUILD_SCRIPT}" \
--cityscapes_root="${CITYSCAPES_ROOT}" \
--output_dir="${OUTPUT_DIR}" \
#!/bin/bash
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Script to download and preprocess the PASCAL VOC 2012 dataset.
#
# Usage:
# bash ./download_and_preprocess_voc2012.sh
#
# The folder structure is assumed to be:
# + data
# - build_data.py
# - build_voc2012_data.py
# - download_and_preprocess_voc2012.sh
# - remove_gt_colormap.py
# + VOCdevkit
# + VOC2012
# + JPEGImages
# + SegmentationClass
#
# Exit immediately if a command exits with a non-zero status.
set -e
CURRENT_DIR=$(pwd)
WORK_DIR="./pascal_voc_seg"
mkdir -p ${WORK_DIR}
cd ${WORK_DIR}
# Helper function to download and unpack VOC 2012 dataset.
function download_and_uncompress() {
local BASE_URL=${1}
local FILENAME=${2}
if [ ! -f ${FILENAME} ]; then
echo "Downloading ${FILENAME} to ${WORK_DIR}"
wget -nd -c "${BASE_URL}/${FILENAME}"
fi
echo "Uncompressing ${FILENAME}"
tar -xf ${FILENAME}
}
# Download the images.
BASE_URL="http://host.robots.ox.ac.uk/pascal/VOC/voc2012/"
FILENAME="VOCtrainval_11-May-2012.tar"
download_and_uncompress ${BASE_URL} ${FILENAME}
cd "${CURRENT_DIR}"
# Root path for PASCAL VOC 2012 dataset.
PASCAL_ROOT="${WORK_DIR}/VOCdevkit/VOC2012"
# Remove the colormap in the ground truth annotations.
SEG_FOLDER="${PASCAL_ROOT}/SegmentationClass"
SEMANTIC_SEG_FOLDER="${PASCAL_ROOT}/SegmentationClassRaw"
echo "Removing the color map in ground truth annotations..."
python ./remove_gt_colormap.py \
--original_gt_folder="${SEG_FOLDER}" \
--output_dir="${SEMANTIC_SEG_FOLDER}"
# Build TFRecords of the dataset.
# First, create output directory for storing TFRecords.
OUTPUT_DIR="${WORK_DIR}/tfrecord"
mkdir -p "${OUTPUT_DIR}"
IMAGE_FOLDER="${PASCAL_ROOT}/JPEGImages"
LIST_FOLDER="${PASCAL_ROOT}/ImageSets/Segmentation"
echo "Converting PASCAL VOC 2012 dataset..."
python ./build_voc2012_data.py \
--image_folder="${IMAGE_FOLDER}" \
--semantic_segmentation_folder="${SEMANTIC_SEG_FOLDER}" \
--list_folder="${LIST_FOLDER}" \
--image_format="jpg" \
--output_dir="${OUTPUT_DIR}"
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Removes the color map from segmentation annotations.
Removes the color map from the ground truth segmentation annotations and saves
the results to output_dir.
"""
import glob
import os.path
import numpy as np
from PIL import Image
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('original_gt_folder',
'./VOCdevkit/VOC2012/SegmentationClass',
'Original ground truth annotations.')
tf.app.flags.DEFINE_string('segmentation_format', 'png', 'Segmentation format.')
tf.app.flags.DEFINE_string('output_dir',
'./VOCdevkit/VOC2012/SegmentationClassRaw',
'Folder to save modified ground truth annotations.')
def _remove_colormap(filename):
"""Removes the color map from the annotation.
Args:
filename: Ground truth annotation filename.
Returns:
Annotation without color map.
"""
return np.array(Image.open(filename))
def _save_annotation(annotation, filename):
"""Saves the annotation as png file.
Args:
annotation: Segmentation annotation.
filename: Output filename.
"""
pil_image = Image.fromarray(annotation.astype(dtype=np.uint8))
with tf.gfile.Open(filename, mode='w') as f:
pil_image.save(f, 'PNG')
def main(unused_argv):
# Create the output directory if it does not exist.
if not tf.gfile.IsDirectory(FLAGS.output_dir):
tf.gfile.MakeDirs(FLAGS.output_dir)
annotations = glob.glob(os.path.join(FLAGS.original_gt_folder,
'*.' + FLAGS.segmentation_format))
for annotation in annotations:
raw_annotation = _remove_colormap(annotation)
filename = os.path.basename(annotation)[:-4]
_save_annotation(raw_annotation,
os.path.join(
FLAGS.output_dir,
filename + '.' + FLAGS.segmentation_format))
if __name__ == '__main__':
tf.app.run()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data from semantic segmentation datasets.
The SegmentationDataset class provides both images and annotations (semantic
segmentation and/or instance segmentation) for TensorFlow. Currently, we
support the following datasets:
1. PASCAL VOC 2012 (http://host.robots.ox.ac.uk/pascal/VOC/voc2012/).
The PASCAL VOC 2012 semantic segmentation dataset annotates 20 foreground
object classes (e.g., bike, person, and so on) and treats all other semantic
classes as a single background class. The dataset contains 1464, 1449, and 1456
annotated images for the training, validation, and test sets, respectively.
2. Cityscapes dataset (https://www.cityscapes-dataset.com)
The Cityscapes dataset contains 19 semantic labels (such as road, person, car,
and so on) for urban street scenes.
References:
M. Everingham, S. M. A. Eslami, L. V. Gool, C. K. I. Williams, J. Winn,
and A. Zisserman, The pascal visual object classes challenge a retrospective.
IJCV, 2014.
M. Cordts, M. Omran, S. Ramos, T. Rehfeld, M. Enzweiler, R. Benenson,
U. Franke, S. Roth, and B. Schiele, "The cityscapes dataset for semantic urban
scene understanding," In Proc. of CVPR, 2016.
"""
import collections
import os.path
import tensorflow as tf
slim = tf.contrib.slim
dataset = slim.dataset
tfexample_decoder = slim.tfexample_decoder
_ITEMS_TO_DESCRIPTIONS = {
'image': 'A color image of varying height and width.',
'labels_class': ('A semantic segmentation label whose size matches image. '
                 'Its values range from 0 (background) to num_classes.'),
}
# Named tuple to describe the dataset properties.
DatasetDescriptor = collections.namedtuple(
'DatasetDescriptor',
['splits_to_sizes', # Splits of the dataset into training, val, and test.
'num_classes', # Number of semantic classes.
'ignore_label', # Ignore label value.
]
)
_CITYSCAPES_INFORMATION = DatasetDescriptor(
splits_to_sizes={
'train': 2975,
'val': 500,
},
num_classes=19,
ignore_label=255,
)
_PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor(
splits_to_sizes={
'train': 1464,
'trainval': 2913,
'val': 1449,
},
num_classes=21,
ignore_label=255,
)
_DATASETS_INFORMATION = {
'cityscapes': _CITYSCAPES_INFORMATION,
'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
}
# Default file pattern of TFRecord of TensorFlow Example.
_FILE_PATTERN = '%s-*'
def get_cityscapes_dataset_name():
return 'cityscapes'
def get_dataset(dataset_name, split_name, dataset_dir):
"""Gets an instance of slim Dataset.
Args:
dataset_name: Dataset name.
split_name: A train/val split name.
dataset_dir: The directory of the dataset sources.
Returns:
An instance of slim Dataset.
Raises:
ValueError: if the dataset_name or split_name is not recognized.
"""
if dataset_name not in _DATASETS_INFORMATION:
raise ValueError('The specified dataset is not supported yet.')
splits_to_sizes = _DATASETS_INFORMATION[dataset_name].splits_to_sizes
if split_name not in splits_to_sizes:
raise ValueError('data split name %s not recognized' % split_name)
# Prepare the variables for different datasets.
num_classes = _DATASETS_INFORMATION[dataset_name].num_classes
ignore_label = _DATASETS_INFORMATION[dataset_name].ignore_label
file_pattern = _FILE_PATTERN
file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
# Specify how the TF-Examples are decoded.
keys_to_features = {
'image/encoded': tf.FixedLenFeature(
(), tf.string, default_value=''),
'image/filename': tf.FixedLenFeature(
(), tf.string, default_value=''),
'image/format': tf.FixedLenFeature(
(), tf.string, default_value='jpeg'),
'image/height': tf.FixedLenFeature(
(), tf.int64, default_value=0),
'image/width': tf.FixedLenFeature(
(), tf.int64, default_value=0),
'image/segmentation/class/encoded': tf.FixedLenFeature(
(), tf.string, default_value=''),
'image/segmentation/class/format': tf.FixedLenFeature(
(), tf.string, default_value='png'),
}
items_to_handlers = {
'image': tfexample_decoder.Image(
image_key='image/encoded',
format_key='image/format',
channels=3),
'image_name': tfexample_decoder.Tensor('image/filename'),
'height': tfexample_decoder.Tensor('image/height'),
'width': tfexample_decoder.Tensor('image/width'),
'labels_class': tfexample_decoder.Image(
image_key='image/segmentation/class/encoded',
format_key='image/segmentation/class/format',
channels=1),
}
decoder = tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handlers)
return dataset.Dataset(
data_sources=file_pattern,
reader=tf.TFRecordReader,
decoder=decoder,
num_samples=splits_to_sizes[split_name],
items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
ignore_label=ignore_label,
num_classes=num_classes,
name=dataset_name,
multi_label=True)
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# DeepLab Demo\n",
"\n",
"This demo will demostrate the steps to run deeplab semantic segmentation model on sample input images.\n",
"\n",
"## Prerequisites\n",
"\n",
"Running this demo requires the following libraries:\n",
"\n",
"* Jupyter notebook (Python 2)\n",
"* Tensorflow (>= v1.5.0)\n",
"* Matplotlib\n",
"* Pillow\n",
"* numpy\n",
"* ipywidgets (follow the setup [here](https://ipywidgets.readthedocs.io/en/stable/user_install.html))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import collections\n",
"import os\n",
"import StringIO\n",
"import sys\n",
"import tarfile\n",
"import tempfile\n",
"import urllib\n",
"\n",
"from IPython import display\n",
"from ipywidgets import interact\n",
"from ipywidgets import interactive\n",
"from matplotlib import gridspec\n",
"from matplotlib import pyplot as plt\n",
"import numpy as np\n",
"from PIL import Image\n",
"\n",
"import tensorflow as tf\n",
"\n",
"if tf.__version__ < '1.5.0':\n",
" raise ImportError('Please upgrade your tensorflow installation to v1.5.0 or newer!')\n",
"\n",
"# Needed to show segmentation colormap labels\n",
"sys.path.append('utils')\n",
"import get_dataset_colormap"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Select and download models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"_MODEL_URLS = {\n",
" 'xception_coco_voctrainaug': 'http://download.tensorflow.org/models/deeplabv3_pascal_train_aug_2018_01_04.tar.gz',\n",
" 'xception_coco_voctrainval': 'http://download.tensorflow.org/models/deeplabv3_pascal_trainval_2018_01_04.tar.gz',\n",
"}\n",
"\n",
"Config = collections.namedtuple('Config', 'model_url, model_dir')\n",
"\n",
"def get_config(model_name, model_dir):\n",
" return Config(_MODEL_URLS[model_name], model_dir)\n",
"\n",
"config_widget = interactive(get_config, model_name=_MODEL_URLS.keys(), model_dir='')\n",
"display.display(config_widget)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Check configuration and download the model\n",
"\n",
"_TARBALL_NAME = 'deeplab_model.tar.gz'\n",
"\n",
"config = config_widget.result\n",
"\n",
"model_dir = config.model_dir or tempfile.mkdtemp()\n",
"tf.gfile.MakeDirs(model_dir)\n",
"\n",
"download_path = os.path.join(model_dir, _TARBALL_NAME)\n",
"print 'downloading model to %s, this might take a while...' % download_path\n",
"urllib.urlretrieve(config.model_url, download_path)\n",
"print 'download completed!'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load model in TensorFlow"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"_FROZEN_GRAPH_NAME = 'frozen_inference_graph'\n",
"\n",
"\n",
"class DeepLabModel(object):\n",
" \"\"\"Class to load deeplab model and run inference.\"\"\"\n",
" \n",
" INPUT_TENSOR_NAME = 'ImageTensor:0'\n",
" OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'\n",
" INPUT_SIZE = 513\n",
"\n",
" def __init__(self, tarball_path):\n",
" \"\"\"Creates and loads pretrained deeplab model.\"\"\"\n",
" self.graph = tf.Graph()\n",
" \n",
" graph_def = None\n",
" # Extract frozen graph from tar archive.\n",
" tar_file = tarfile.open(tarball_path)\n",
" for tar_info in tar_file.getmembers():\n",
" if _FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):\n",
" file_handle = tar_file.extractfile(tar_info)\n",
" graph_def = tf.GraphDef.FromString(file_handle.read())\n",
" break\n",
"\n",
" tar_file.close()\n",
" \n",
" if graph_def is None:\n",
" raise RuntimeError('Cannot find inference graph in tar archive.')\n",
"\n",
" with self.graph.as_default(): \n",
" tf.import_graph_def(graph_def, name='')\n",
" \n",
" self.sess = tf.Session(graph=self.graph)\n",
" \n",
" def run(self, image):\n",
" \"\"\"Runs inference on a single image.\n",
" \n",
" Args:\n",
" image: A PIL.Image object, raw input image.\n",
" \n",
" Returns:\n",
" resized_image: RGB image resized from original input image.\n",
" seg_map: Segmentation map of `resized_image`.\n",
" \"\"\"\n",
" width, height = image.size\n",
" resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)\n",
" target_size = (int(resize_ratio * width), int(resize_ratio * height))\n",
" resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS)\n",
" batch_seg_map = self.sess.run(\n",
" self.OUTPUT_TENSOR_NAME,\n",
" feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})\n",
" seg_map = batch_seg_map[0]\n",
" return resized_image, seg_map\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = DeepLabModel(download_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Helper methods"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"LABEL_NAMES = np.asarray([\n",
" 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',\n",
" 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',\n",
" 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',\n",
" 'train', 'tv'\n",
"])\n",
"\n",
"FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)\n",
"FULL_COLOR_MAP = get_dataset_colormap.label_to_color_image(FULL_LABEL_MAP)\n",
"\n",
"\n",
"def vis_segmentation(image, seg_map):\n",
" plt.figure(figsize=(15, 5))\n",
" grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])\n",
"\n",
" plt.subplot(grid_spec[0])\n",
" plt.imshow(image)\n",
" plt.axis('off')\n",
" plt.title('input image')\n",
" \n",
" plt.subplot(grid_spec[1])\n",
" seg_image = get_dataset_colormap.label_to_color_image(\n",
" seg_map, get_dataset_colormap.get_pascal_name()).astype(np.uint8)\n",
" plt.imshow(seg_image)\n",
" plt.axis('off')\n",
" plt.title('segmentation map')\n",
"\n",
" plt.subplot(grid_spec[2])\n",
" plt.imshow(image)\n",
" plt.imshow(seg_image, alpha=0.7)\n",
" plt.axis('off')\n",
" plt.title('segmentation overlay')\n",
" \n",
" unique_labels = np.unique(seg_map)\n",
" ax = plt.subplot(grid_spec[3])\n",
" plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')\n",
" ax.yaxis.tick_right()\n",
" plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])\n",
" plt.xticks([], [])\n",
" ax.tick_params(width=0)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run on sample images"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Note that we are using single scale inference in the demo for fast\n",
"# computation, so the results may slightly differ from the visualizations\n",
"# in README, which uses multi-scale and left-right flipped inputs.\n",
"\n",
"IMAGE_DIR = 'g3doc/img'\n",
"\n",
"def run_demo_image(image_name):\n",
" try:\n",
" image_path = os.path.join(IMAGE_DIR, image_name)\n",
" orignal_im = Image.open(image_path)\n",
" except IOError:\n",
" print 'Failed to read image from %s.' % image_path \n",
" return \n",
" print 'running deeplab on image %s...' % image_name\n",
" resized_im, seg_map = model.run(orignal_im)\n",
" \n",
" vis_segmentation(resized_im, seg_map)\n",
"\n",
"_ = interact(run_demo_image, image_name=['image1.jpg', 'image2.jpg', 'image3.jpg'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run on internet images"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"def get_an_internet_image(url):\n",
" if not url:\n",
" return\n",
"\n",
" try:\n",
" # Prefix with 'file://' for local file.\n",
" if os.path.exists(url):\n",
" url = 'file://' + url\n",
" f = urllib.urlopen(url)\n",
" jpeg_str = f.read()\n",
" except IOError:\n",
" print 'invalid url: ' + url\n",
" return\n",
"\n",
" orignal_im = Image.open(StringIO.StringIO(jpeg_str))\n",
" print 'running deeplab on image %s...' % url\n",
" resized_im, seg_map = model.run(orignal_im)\n",
" \n",
" vis_segmentation(resized_im, seg_map)\n",
"\n",
"_ = interact(get_an_internet_image, url='')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation script for the DeepLab model.
See model.py for more details and usage.
"""
import math
import tensorflow as tf
from deeplab import common
from deeplab import model
from deeplab.datasets import segmentation_dataset
from deeplab.utils import input_generator
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')
# Settings for log directories.
flags.DEFINE_string('eval_logdir', None, 'Where to write the event logs.')
flags.DEFINE_string('checkpoint_dir', None, 'Directory of model checkpoints.')
# Settings for evaluating the model.
flags.DEFINE_integer('eval_batch_size', 1,
'The number of images in each batch during evaluation.')
flags.DEFINE_multi_integer('eval_crop_size', [513, 513],
'Image crop size [height, width] for evaluation.')
flags.DEFINE_integer('eval_interval_secs', 60 * 5,
'How often (in seconds) to run evaluation.')
# For `xception_65`, use atrous_rates = [12, 24, 36] if output_stride = 8, or
# rates = [6, 12, 18] if output_stride = 16. Note one could use different
# atrous_rates/output_stride during training/evaluation.
flags.DEFINE_multi_integer('atrous_rates', None,
'Atrous rates for atrous spatial pyramid pooling.')
flags.DEFINE_integer('output_stride', 16,
'The ratio of input to output spatial resolution.')
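# For example, for `xception_65` with output_stride = 16 one would typically
# pass (multi-valued flags are repeated on the command line):
#   --output_stride=16 --atrous_rates=6 --atrous_rates=12 --atrous_rates=18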
# Change to [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] for multi-scale test.
flags.DEFINE_multi_float('eval_scales', [1.0],
'The scales to resize images for evaluation.')
# Change to True for adding flipped images during test.
flags.DEFINE_bool('add_flipped_images', False,
'Add flipped images for evaluation or not.')
# Dataset settings.
flags.DEFINE_string('dataset', 'pascal_voc_seg',
'Name of the segmentation dataset.')
flags.DEFINE_string('eval_split', 'val',
'Which split of the dataset is used for evaluation.')
flags.DEFINE_string('dataset_dir', None, 'Where the dataset resides.')
flags.DEFINE_integer('max_number_of_evaluations', 0,
'Maximum number of eval iterations. Will loop '
'indefinitely upon nonpositive values.')
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
# Get dataset-dependent information.
dataset = segmentation_dataset.get_dataset(
FLAGS.dataset, FLAGS.eval_split, dataset_dir=FLAGS.dataset_dir)
tf.gfile.MakeDirs(FLAGS.eval_logdir)
tf.logging.info('Evaluating on %s set', FLAGS.eval_split)
with tf.Graph().as_default():
samples = input_generator.get(
dataset,
FLAGS.eval_crop_size,
FLAGS.eval_batch_size,
min_resize_value=FLAGS.min_resize_value,
max_resize_value=FLAGS.max_resize_value,
resize_factor=FLAGS.resize_factor,
dataset_split=FLAGS.eval_split,
is_training=False,
model_variant=FLAGS.model_variant)
model_options = common.ModelOptions(
outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_classes},
crop_size=FLAGS.eval_crop_size,
atrous_rates=FLAGS.atrous_rates,
output_stride=FLAGS.output_stride)
if tuple(FLAGS.eval_scales) == (1.0,):
tf.logging.info('Performing single-scale test.')
predictions = model.predict_labels(samples[common.IMAGE], model_options,
image_pyramid=FLAGS.image_pyramid)
else:
tf.logging.info('Performing multi-scale test.')
predictions = model.predict_labels_multi_scale(
samples[common.IMAGE],
model_options=model_options,
eval_scales=FLAGS.eval_scales,
add_flipped_images=FLAGS.add_flipped_images)
predictions = predictions[common.OUTPUT_TYPE]
predictions = tf.reshape(predictions, shape=[-1])
labels = tf.reshape(samples[common.LABEL], shape=[-1])
weights = tf.to_float(tf.not_equal(labels, dataset.ignore_label))
# Set ignore_label regions to label 0, because metrics.mean_iou requires
# range of labels = [0, dataset.num_classes). Note the ignore_label regions
# are not evaluated since the corresponding regions contain weights = 0.
labels = tf.where(
tf.equal(labels, dataset.ignore_label), tf.zeros_like(labels), labels)
predictions_tag = 'miou'
for eval_scale in FLAGS.eval_scales:
predictions_tag += '_' + str(eval_scale)
if FLAGS.add_flipped_images:
predictions_tag += '_flipped'
# Define the evaluation metric.
metric_map = {}
metric_map[predictions_tag] = tf.metrics.mean_iou(
predictions, labels, dataset.num_classes, weights=weights)
metrics_to_values, metrics_to_updates = (
tf.contrib.metrics.aggregate_metric_map(metric_map))
for metric_name, metric_value in metrics_to_values.iteritems():
slim.summaries.add_scalar_summary(
metric_value, metric_name, print_summary=True)
num_batches = int(
math.ceil(dataset.num_samples / float(FLAGS.eval_batch_size)))
tf.logging.info('Eval num images %d', dataset.num_samples)
tf.logging.info('Eval batch size %d and num batch %d',
FLAGS.eval_batch_size, num_batches)
num_eval_iters = None
if FLAGS.max_number_of_evaluations > 0:
num_eval_iters = FLAGS.max_number_of_evaluations
slim.evaluation.evaluation_loop(
master=FLAGS.master,
checkpoint_dir=FLAGS.checkpoint_dir,
logdir=FLAGS.eval_logdir,
num_evals=num_batches,
eval_op=metrics_to_updates.values(),
max_number_of_evaluations=num_eval_iters,
eval_interval_secs=FLAGS.eval_interval_secs)
if __name__ == '__main__':
flags.mark_flag_as_required('checkpoint_dir')
flags.mark_flag_as_required('eval_logdir')
flags.mark_flag_as_required('dataset_dir')
tf.app.run()