Unverified Commit f798e4b5 authored by aquariusjay, committed by GitHub

Merge pull request #4311 from YknZhu/master

PiperOrigin-RevId: 197225788
parents 9dec261e afb2a7dc
......@@ -104,14 +104,33 @@ Misc:
To get help with issues you may encounter while using the DeepLab Tensorflow
implementation, create a new question on
[StackOverflow](https://stackoverflow.com/) with the tags "tensorflow" and
"deeplab".
[StackOverflow](https://stackoverflow.com/) with the tag "tensorflow".
Please report bugs (i.e., broken code, not usage questions) to the
tensorflow/models GitHub [issue
tracker](https://github.com/tensorflow/models/issues), prefixing the issue name
with "deeplab".
## Change Logs
### May 18, 2018
1. Added builders for ResNet-v1 and Xception model variants.
1. Added ADE20K support, including colormap and pretrained Xception_65 checkpoint.
1. Fixed a bug on using non-default depth_multiplier for MobileNet-v2.
### March 22, 2018
Released checkpoints using MobileNet-V2 as network backbone and pretrained on
PASCAL VOC 2012 and Cityscapes.
### March 5, 2018
First release of DeepLab in TensorFlow including deeper Xception network
backbone. Included checkpoints that have been pretrained on PASCAL VOC 2012
and Cityscapes.
## References
1. **Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs**<br />
......
......@@ -95,6 +95,7 @@ ORIGINAL_IMAGE = 'original_image'
# Test set name.
TEST_SET = 'test'
class ModelOptions(
collections.namedtuple('ModelOptions', [
'outputs_to_num_classes',
......@@ -109,7 +110,8 @@ class ModelOptions(
'decoder_output_stride',
'decoder_use_separable_conv',
'logits_kernel_size',
'model_variant'
'model_variant',
'depth_multiplier',
])):
"""Immutable class to hold model options."""
......@@ -139,4 +141,4 @@ class ModelOptions(
FLAGS.aspp_with_batch_norm, FLAGS.aspp_with_separable_conv,
FLAGS.multi_grid, FLAGS.decoder_output_stride,
FLAGS.decoder_use_separable_conv, FLAGS.logits_kernel_size,
FLAGS.model_variant)
FLAGS.model_variant, FLAGS.depth_multiplier)
......@@ -17,8 +17,9 @@
import functools
import tensorflow as tf
from deeplab.core import resnet_v1_beta
from deeplab.core import xception
from nets.mobilenet import mobilenet as mobilenet_lib
from tensorflow.contrib.slim.nets import resnet_utils
from nets.mobilenet import mobilenet_v2
......@@ -56,10 +57,12 @@ def _mobilenet_v2(net,
"""
with tf.variable_scope(
scope, 'MobilenetV2', [net], reuse=reuse) as scope:
return mobilenet_lib.mobilenet_base(
return mobilenet_v2.mobilenet_base(
net,
conv_defs=mobilenet_v2.V2_DEF,
multiplier=depth_multiplier,
depth_multiplier=depth_multiplier,
min_depth=8 if depth_multiplier == 1.0 else 1,
divisible_by=8 if depth_multiplier == 1.0 else 1,
final_endpoint=final_endpoint or _MOBILENET_V2_FINAL_ENDPOINT,
output_stride=output_stride,
scope=scope)
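# Why min_depth and divisible_by are relaxed for non-default multipliers
# (a sketch of slim's depth rounding, roughly
# max(min_depth, round_to_multiple(depth * multiplier, divisible_by))):
# with the defaults (8, 8), depth_multiplier=0.25 on a 16-channel layer is
# rounded back up to 8, silently undoing the multiplier. Setting both to 1
# preserves the true scaled depth (4), so checkpoints trained with a
# non-default multiplier load correctly.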
......@@ -68,13 +71,25 @@ def _mobilenet_v2(net,
# A map from network name to network function.
networks_map = {
'mobilenet_v2': _mobilenet_v2,
'resnet_v1_50': resnet_v1_beta.resnet_v1_50,
'resnet_v1_50_beta': resnet_v1_beta.resnet_v1_50_beta,
'resnet_v1_101': resnet_v1_beta.resnet_v1_101,
'resnet_v1_101_beta': resnet_v1_beta.resnet_v1_101_beta,
'xception_41': xception.xception_41,
'xception_65': xception.xception_65,
'xception_71': xception.xception_71,
}
# A map from network name to network arg scope.
arg_scopes_map = {
'mobilenet_v2': mobilenet_v2.training_scope,
'resnet_v1_50': resnet_utils.resnet_arg_scope,
'resnet_v1_50_beta': resnet_utils.resnet_arg_scope,
'resnet_v1_101': resnet_utils.resnet_arg_scope,
'resnet_v1_101_beta': resnet_utils.resnet_arg_scope,
'xception_41': xception.xception_arg_scope,
'xception_65': xception.xception_arg_scope,
'xception_71': xception.xception_arg_scope,
}
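# Sketch of how these two registries are typically consumed together
# (see extract_features below; argument values here are illustrative):
#   arg_scope = arg_scopes_map[model_variant](weight_decay=0.0001)
#   with slim.arg_scope(arg_scope):
#     features, end_points = networks_map[model_variant](images)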
# Names for end point features.
......@@ -86,19 +101,49 @@ networks_to_feature_maps = {
# The provided checkpoint does not include decoder module.
DECODER_END_POINTS: None,
},
'resnet_v1_50': {
DECODER_END_POINTS: ['block1/unit_2/bottleneck_v1/conv3'],
},
'resnet_v1_50_beta': {
DECODER_END_POINTS: ['block1/unit_2/bottleneck_v1/conv3'],
},
'resnet_v1_101': {
DECODER_END_POINTS: ['block1/unit_2/bottleneck_v1/conv3'],
},
'resnet_v1_101_beta': {
DECODER_END_POINTS: ['block1/unit_2/bottleneck_v1/conv3'],
},
'xception_41': {
DECODER_END_POINTS: [
'entry_flow/block2/unit_1/xception_module/'
'separable_conv2_pointwise',
],
},
'xception_65': {
DECODER_END_POINTS: [
'entry_flow/block2/unit_1/xception_module/'
'separable_conv2_pointwise',
],
}
},
'xception_71': {
DECODER_END_POINTS: [
'entry_flow/block2/unit_1/xception_module/'
'separable_conv2_pointwise',
],
},
}
# A map from feature extractor name to the network name scope used in the
# ImageNet pretrained versions of these models.
name_scope = {
'mobilenet_v2': 'MobilenetV2',
'resnet_v1_50': 'resnet_v1_50',
'resnet_v1_50_beta': 'resnet_v1_50',
'resnet_v1_101': 'resnet_v1_101',
'resnet_v1_101_beta': 'resnet_v1_101',
'xception_41': 'xception_41',
'xception_65': 'xception_65',
'xception_71': 'xception_71',
}
# Mean pixel value.
......@@ -118,7 +163,13 @@ def _preprocess_zero_mean_unit_range(inputs):
_PREPROCESS_FN = {
'mobilenet_v2': _preprocess_zero_mean_unit_range,
'resnet_v1_50': _preprocess_subtract_imagenet_mean,
'resnet_v1_50_beta': _preprocess_zero_mean_unit_range,
'resnet_v1_101': _preprocess_subtract_imagenet_mean,
'resnet_v1_101_beta': _preprocess_zero_mean_unit_range,
'xception_41': _preprocess_zero_mean_unit_range,
'xception_65': _preprocess_zero_mean_unit_range,
'xception_71': _preprocess_zero_mean_unit_range,
}
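# For reference, a sketch of the two preprocessing styles mapped above:
#   _preprocess_subtract_imagenet_mean(x) ~ tf.to_float(x) - _MEAN_RGB
#   _preprocess_zero_mean_unit_range(x)   = (2.0 / 255.0) * tf.to_float(x) - 1.0
# The latter maps [0, 255] pixel values into [-1, 1].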
......@@ -140,7 +191,8 @@ def mean_pixel(model_variant=None):
Returns:
Mean pixel value.
"""
if model_variant is None:
if model_variant in ['resnet_v1_50',
'resnet_v1_101'] or model_variant is None:
return _MEAN_RGB
else:
return [127.5, 127.5, 127.5]
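# Illustrative behavior of the updated mean_pixel():
#   mean_pixel('resnet_v1_50')      -> _MEAN_RGB (ImageNet mean subtraction)
#   mean_pixel('resnet_v1_50_beta') -> [127.5, 127.5, 127.5]
#   mean_pixel('xception_65')       -> [127.5, 127.5, 127.5]
# 127.5 is the midpoint of [0, 255], matching zero-mean unit-range scaling.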
......@@ -159,7 +211,8 @@ def extract_features(images,
regularize_depthwise=False,
preprocess_images=True,
num_classes=None,
global_pool=False):
global_pool=False,
use_bounded_activations=False):
"""Extracts features by the particular model_variant.
Args:
......@@ -184,6 +237,8 @@ def extract_features(images,
to None for dense prediction tasks.
global_pool: Global pooling for image classification task. Defaults to
False, since dense prediction tasks do not use this.
use_bounded_activations: Whether or not to use bounded activations. Bounded
activations better lend themselves to quantized inference.
Returns:
features: A tensor of size [batch, feature_height, feature_width,
......@@ -195,7 +250,25 @@ def extract_features(images,
Raises:
ValueError: Unrecognized model variant.
"""
if 'xception' in model_variant:
if 'resnet' in model_variant:
arg_scope = arg_scopes_map[model_variant](
weight_decay=weight_decay,
batch_norm_decay=0.95,
batch_norm_epsilon=1e-5,
batch_norm_scale=True,
activation_fn=tf.nn.relu6 if use_bounded_activations else tf.nn.relu)
features, end_points = get_network(
model_variant, preprocess_images, arg_scope)(
inputs=images,
num_classes=num_classes,
is_training=(is_training and fine_tune_batch_norm),
global_pool=global_pool,
output_stride=output_stride,
multi_grid=multi_grid,
reuse=reuse,
scope=name_scope[model_variant],
use_bounded_activations=use_bounded_activations)
elif 'xception' in model_variant:
arg_scope = arg_scopes_map[model_variant](
weight_decay=weight_decay,
batch_norm_decay=0.9997,
......
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for resnet_v1_beta module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import numpy as np
import tensorflow as tf
from deeplab.core import resnet_v1_beta
from tensorflow.contrib.slim.nets import resnet_utils
slim = tf.contrib.slim
def create_test_input(batch, height, width, channels):
"""Create test input tensor."""
if None in [batch, height, width, channels]:
return tf.placeholder(tf.float32, (batch, height, width, channels))
else:
return tf.to_float(
np.tile(
np.reshape(
np.reshape(np.arange(height), [height, 1]) +
np.reshape(np.arange(width), [1, width]),
[1, height, width, 1]),
[batch, 1, 1, channels]))
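# The dense branch yields value h + w at spatial location (h, w), tiled over
# batch and channels; e.g. create_test_input(1, 2, 3, 1) evaluates to
#   [[[[0], [1], [2]],
#     [[1], [2], [3]]]]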
class ResnetCompleteNetworkTest(tf.test.TestCase):
"""Tests with complete small ResNet v1 networks."""
def _resnet_small(self,
inputs,
num_classes=None,
is_training=True,
global_pool=True,
output_stride=None,
store_non_strided_activations=False,
multi_grid=None,
reuse=None,
scope='resnet_v1_small'):
"""A shallow and thin ResNet v1 for faster tests."""
if multi_grid is None:
multi_grid = [1, 1, 1]
else:
if len(multi_grid) != 3:
raise ValueError('Expect multi_grid to have length 3.')
block = resnet_v1_beta.resnet_v1_beta_block
blocks = [
block('block1', base_depth=1, num_units=3, stride=2),
block('block2', base_depth=2, num_units=3, stride=2),
block('block3', base_depth=4, num_units=3, stride=2),
resnet_utils.Block('block4', resnet_v1_beta.bottleneck, [
{'depth': 32,
'depth_bottleneck': 8,
'stride': 1,
'unit_rate': rate} for rate in multi_grid])]
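    # With e.g. multi_grid=[1, 2, 4], the three block4 units apply unit rates
    # 1, 2, and 4 on top of the base atrous rate implied by output_stride,
    # growing the receptive field without further downsampling.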
return resnet_v1_beta.resnet_v1_beta(
inputs,
blocks,
num_classes=num_classes,
is_training=is_training,
global_pool=global_pool,
output_stride=output_stride,
root_block_fn=functools.partial(
resnet_v1_beta.root_block_fn_for_beta_variant),
store_non_strided_activations=store_non_strided_activations,
reuse=reuse,
scope=scope)
def testClassificationEndPoints(self):
global_pool = True
num_classes = 10
inputs = create_test_input(2, 224, 224, 3)
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
logits, end_points = self._resnet_small(inputs,
num_classes,
global_pool=global_pool,
scope='resnet')
self.assertTrue(logits.op.name.startswith('resnet/logits'))
self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
self.assertTrue('predictions' in end_points)
self.assertListEqual(end_points['predictions'].get_shape().as_list(),
[2, 1, 1, num_classes])
def testClassificationEndPointsWithMultigrid(self):
global_pool = True
num_classes = 10
inputs = create_test_input(2, 224, 224, 3)
multi_grid = [1, 2, 4]
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
logits, end_points = self._resnet_small(inputs,
num_classes,
global_pool=global_pool,
multi_grid=multi_grid,
scope='resnet')
self.assertTrue(logits.op.name.startswith('resnet/logits'))
self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
self.assertTrue('predictions' in end_points)
self.assertListEqual(end_points['predictions'].get_shape().as_list(),
[2, 1, 1, num_classes])
def testClassificationShapes(self):
global_pool = True
num_classes = 10
inputs = create_test_input(2, 224, 224, 3)
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
_, end_points = self._resnet_small(inputs,
num_classes,
global_pool=global_pool,
scope='resnet')
endpoint_to_shape = {
'resnet/conv1_1': [2, 112, 112, 64],
'resnet/conv1_2': [2, 112, 112, 64],
'resnet/conv1_3': [2, 112, 112, 128],
'resnet/block1': [2, 28, 28, 4],
'resnet/block2': [2, 14, 14, 8],
'resnet/block3': [2, 7, 7, 16],
'resnet/block4': [2, 7, 7, 32]}
for endpoint, shape in endpoint_to_shape.items():
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testFullyConvolutionalEndpointShapes(self):
global_pool = False
num_classes = 10
inputs = create_test_input(2, 321, 321, 3)
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
_, end_points = self._resnet_small(inputs,
num_classes,
global_pool=global_pool,
scope='resnet')
endpoint_to_shape = {
'resnet/conv1_1': [2, 161, 161, 64],
'resnet/conv1_2': [2, 161, 161, 64],
'resnet/conv1_3': [2, 161, 161, 128],
'resnet/block1': [2, 41, 41, 4],
'resnet/block2': [2, 21, 21, 8],
'resnet/block3': [2, 11, 11, 16],
'resnet/block4': [2, 11, 11, 32]}
for endpoint, shape in endpoint_to_shape.items():
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testAtrousFullyConvolutionalEndpointShapes(self):
global_pool = False
num_classes = 10
output_stride = 8
inputs = create_test_input(2, 321, 321, 3)
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
_, end_points = self._resnet_small(inputs,
num_classes,
global_pool=global_pool,
output_stride=output_stride,
scope='resnet')
endpoint_to_shape = {
'resnet/conv1_1': [2, 161, 161, 64],
'resnet/conv1_2': [2, 161, 161, 64],
'resnet/conv1_3': [2, 161, 161, 128],
'resnet/block1': [2, 41, 41, 4],
'resnet/block2': [2, 41, 41, 8],
'resnet/block3': [2, 41, 41, 16],
'resnet/block4': [2, 41, 41, 32]}
for endpoint, shape in endpoint_to_shape.items():
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testAtrousFullyConvolutionalValues(self):
"""Verify dense feature extraction with atrous convolution."""
nominal_stride = 32
for output_stride in [4, 8, 16, 32, None]:
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
with tf.Graph().as_default():
with self.test_session() as sess:
tf.set_random_seed(0)
inputs = create_test_input(2, 81, 81, 3)
# Dense feature extraction followed by subsampling.
output, _ = self._resnet_small(inputs,
None,
is_training=False,
global_pool=False,
output_stride=output_stride)
if output_stride is None:
factor = 1
else:
factor = nominal_stride // output_stride
output = resnet_utils.subsample(output, factor)
# Make the two networks use the same weights.
tf.get_variable_scope().reuse_variables()
# Feature extraction at the nominal network rate.
expected, _ = self._resnet_small(inputs,
None,
is_training=False,
global_pool=False)
sess.run(tf.global_variables_initializer())
self.assertAllClose(output.eval(), expected.eval(),
atol=1e-4, rtol=1e-4)
def testUnknownBatchSize(self):
batch = 2
height, width = 65, 65
global_pool = True
num_classes = 10
inputs = create_test_input(None, height, width, 3)
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
logits, _ = self._resnet_small(inputs,
num_classes,
global_pool=global_pool,
scope='resnet')
self.assertTrue(logits.op.name.startswith('resnet/logits'))
self.assertListEqual(logits.get_shape().as_list(),
[None, 1, 1, num_classes])
images = create_test_input(batch, height, width, 3)
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
output = sess.run(logits, {inputs: images.eval()})
self.assertEqual(output.shape, (batch, 1, 1, num_classes))
def testFullyConvolutionalUnknownHeightWidth(self):
batch = 2
height, width = 65, 65
global_pool = False
inputs = create_test_input(batch, None, None, 3)
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
output, _ = self._resnet_small(inputs,
None,
global_pool=global_pool)
self.assertListEqual(output.get_shape().as_list(),
[batch, None, None, 32])
images = create_test_input(batch, height, width, 3)
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
output = sess.run(output, {inputs: images.eval()})
self.assertEqual(output.shape, (batch, 3, 3, 32))
def testAtrousFullyConvolutionalUnknownHeightWidth(self):
batch = 2
height, width = 65, 65
global_pool = False
output_stride = 8
inputs = create_test_input(batch, None, None, 3)
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
output, _ = self._resnet_small(inputs,
None,
global_pool=global_pool,
output_stride=output_stride)
self.assertListEqual(output.get_shape().as_list(),
[batch, None, None, 32])
images = create_test_input(batch, height, width, 3)
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
output = sess.run(output, {inputs: images.eval()})
self.assertEqual(output.shape, (batch, 9, 9, 32))
if __name__ == '__main__':
tf.test.main()
......@@ -493,6 +493,73 @@ def xception_block(scope,
}] * num_units)
def xception_41(inputs,
num_classes=None,
is_training=True,
global_pool=True,
keep_prob=0.5,
output_stride=None,
regularize_depthwise=False,
multi_grid=None,
reuse=None,
scope='xception_41'):
"""Xception-41 model."""
blocks = [
xception_block('entry_flow/block1',
depth_list=[128, 128, 128],
skip_connection_type='conv',
activation_fn_in_separable_conv=False,
regularize_depthwise=regularize_depthwise,
num_units=1,
stride=2),
xception_block('entry_flow/block2',
depth_list=[256, 256, 256],
skip_connection_type='conv',
activation_fn_in_separable_conv=False,
regularize_depthwise=regularize_depthwise,
num_units=1,
stride=2),
xception_block('entry_flow/block3',
depth_list=[728, 728, 728],
skip_connection_type='conv',
activation_fn_in_separable_conv=False,
regularize_depthwise=regularize_depthwise,
num_units=1,
stride=2),
xception_block('middle_flow/block1',
depth_list=[728, 728, 728],
skip_connection_type='sum',
activation_fn_in_separable_conv=False,
regularize_depthwise=regularize_depthwise,
num_units=8,
stride=1),
xception_block('exit_flow/block1',
depth_list=[728, 1024, 1024],
skip_connection_type='conv',
activation_fn_in_separable_conv=False,
regularize_depthwise=regularize_depthwise,
num_units=1,
stride=2),
xception_block('exit_flow/block2',
depth_list=[1536, 1536, 2048],
skip_connection_type='none',
activation_fn_in_separable_conv=True,
regularize_depthwise=regularize_depthwise,
num_units=1,
stride=1,
unit_rate_list=multi_grid),
]
return xception(inputs,
blocks=blocks,
num_classes=num_classes,
is_training=is_training,
global_pool=global_pool,
keep_prob=keep_prob,
output_stride=output_stride,
reuse=reuse,
scope=scope)
def xception_65(inputs,
num_classes=None,
is_training=True,
......@@ -560,6 +627,87 @@ def xception_65(inputs,
scope=scope)
def xception_71(inputs,
num_classes=None,
is_training=True,
global_pool=True,
keep_prob=0.5,
output_stride=None,
regularize_depthwise=False,
multi_grid=None,
reuse=None,
scope='xception_71'):
"""Xception-71 model."""
blocks = [
xception_block('entry_flow/block1',
depth_list=[128, 128, 128],
skip_connection_type='conv',
activation_fn_in_separable_conv=False,
regularize_depthwise=regularize_depthwise,
num_units=1,
stride=2),
xception_block('entry_flow/block2',
depth_list=[256, 256, 256],
skip_connection_type='conv',
activation_fn_in_separable_conv=False,
regularize_depthwise=regularize_depthwise,
num_units=1,
stride=1),
xception_block('entry_flow/block3',
depth_list=[256, 256, 256],
skip_connection_type='conv',
activation_fn_in_separable_conv=False,
regularize_depthwise=regularize_depthwise,
num_units=1,
stride=2),
xception_block('entry_flow/block4',
depth_list=[728, 728, 728],
skip_connection_type='conv',
activation_fn_in_separable_conv=False,
regularize_depthwise=regularize_depthwise,
num_units=1,
stride=1),
xception_block('entry_flow/block5',
depth_list=[728, 728, 728],
skip_connection_type='conv',
activation_fn_in_separable_conv=False,
regularize_depthwise=regularize_depthwise,
num_units=1,
stride=2),
xception_block('middle_flow/block1',
depth_list=[728, 728, 728],
skip_connection_type='sum',
activation_fn_in_separable_conv=False,
regularize_depthwise=regularize_depthwise,
num_units=16,
stride=1),
xception_block('exit_flow/block1',
depth_list=[728, 1024, 1024],
skip_connection_type='conv',
activation_fn_in_separable_conv=False,
regularize_depthwise=regularize_depthwise,
num_units=1,
stride=2),
xception_block('exit_flow/block2',
depth_list=[1536, 1536, 2048],
skip_connection_type='none',
activation_fn_in_separable_conv=True,
regularize_depthwise=regularize_depthwise,
num_units=1,
stride=1,
unit_rate_list=multi_grid),
]
return xception(inputs,
blocks=blocks,
num_classes=num_classes,
is_training=is_training,
global_pool=global_pool,
keep_prob=keep_prob,
output_stride=output_stride,
reuse=reuse,
scope=scope)
def xception_arg_scope(weight_decay=0.00004,
batch_norm_decay=0.9997,
batch_norm_epsilon=0.001,
......
......@@ -14,8 +14,8 @@
# ==============================================================================
"""Tests for xception.py."""
import six
import numpy as np
import six
import tensorflow as tf
from deeplab.core import xception
......
......@@ -13,10 +13,11 @@
# limitations under the License.
# ==============================================================================
"""Converts ADE20K data to TFRecord file format with Example protos."""
import math
import os
import random
import string
import sys
import build_data
import tensorflow as tf
......@@ -44,12 +45,13 @@ tf.app.flags.DEFINE_string(
tf.app.flags.DEFINE_string(
'output_dir', './ADE20K/tfrecord',
'Path to save converted SSTable of Tensorflow example')
'Path to save converted tfrecord of Tensorflow example')
_NUM_SHARDS = 4
def _convert_dataset(dataset_split, dataset_dir, dataset_label_dir):
""" Converts the ADE20k dataset into into tfrecord format (SSTable).
"""Converts the ADE20k dataset into into tfrecord format.
Args:
dataset_split: Dataset split (e.g., train, val).
......@@ -65,7 +67,7 @@ def _convert_dataset(dataset_split, dataset_dir, dataset_label_dir):
seg_names = []
for f in img_names:
# Get the filename without the extension.
basename = os.path.basename(f).split(".")[0]
basename = os.path.basename(f).split('.')[0]
# Find its corresponding *_seg.png.
seg = os.path.join(dataset_label_dir, basename+'.png')
seg_names.append(seg)
......@@ -104,10 +106,13 @@ def _convert_dataset(dataset_split, dataset_dir, dataset_label_dir):
sys.stdout.write('\n')
sys.stdout.flush()
def main(unused_argv):
tf.gfile.MakeDirs(FLAGS.output_dir)
_convert_dataset('train', FLAGS.train_image_folder, FLAGS.train_image_label_folder)
_convert_dataset(
'train', FLAGS.train_image_folder, FLAGS.train_image_label_folder)
_convert_dataset('val', FLAGS.val_image_folder, FLAGS.val_image_label_folder)
if __name__ == '__main__':
tf.app.run()
......@@ -129,7 +129,8 @@ def _bytes_list_feature(values):
def norm2bytes(value):
return value.encode() if isinstance(value, str) and six.PY3 else value
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[norm2bytes(values)]))
return tf.train.Feature(
bytes_list=tf.train.BytesList(value=[norm2bytes(values)]))
def image_seg_to_tfexample(image_data, filename, height, width, seg_data):
......
......@@ -69,7 +69,10 @@ _ITEMS_TO_DESCRIPTIONS = {
DatasetDescriptor = collections.namedtuple(
'DatasetDescriptor',
['splits_to_sizes', # Splits of the dataset into training, val, and test.
'num_classes', # Number of semantic classes.
'num_classes', # Number of semantic classes, including the background
# class (if exists). For example, there are 20
# foreground classes + 1 background class in the PASCAL
# VOC 2012 dataset. Thus, we set num_classes=21.
'ignore_label', # Ignore label value.
]
)
......@@ -96,12 +99,12 @@ _PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor(
# These numbers (i.e., 'train'/'test') seem to have to be hard coded.
# You are required to figure them out for your training/testing example.
_ADE20K_INFORMATION = DatasetDescriptor(
splits_to_sizes = {
splits_to_sizes={
'train': 20210, # num of samples in images/training
'val': 2000, # num of samples in images/validation
},
num_classes=150,
ignore_label=255,
num_classes=151,
ignore_label=0,
)
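# Note: ADE20K labels run 0..150, where label 0 marks the unlabeled region;
# hence num_classes=151, while ignore_label=0 excludes those pixels from the
# loss.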
......
......@@ -17,8 +17,8 @@
See model.py for more details and usage.
"""
import six
import math
import six
import tensorflow as tf
from deeplab import common
from deeplab import model
......
......@@ -13,8 +13,7 @@ convert ADE20K semantic segmentation dataset to TFRecord.
bash download_and_convert_ade20k.sh
```
The converted dataset will be saved at
./deeplab/datasets/ADE20K/tfrecord
The converted dataset will be saved at ./deeplab/datasets/ADE20K/tfrecord
## Recommended Directory Structure for Training and Evaluation
......@@ -50,7 +49,7 @@ A local training job using `xception_65` can be run with the following command:
# From tensorflow/models/research/
python deeplab/train.py \
--logtostderr \
--training_number_of_steps=50000 \
--training_number_of_steps=90000 \
--train_split="train" \
--model_variant="xception_65" \
--atrous_rates=6 \
......@@ -61,21 +60,16 @@ python deeplab/train.py \
--train_crop_size=513 \
--train_crop_size=513 \
--train_batch_size=4 \
--min_resize_value=350 \
--max_resize_value=500 \
--min_resize_value=513 \
--max_resize_value=513 \
--resize_factor=16 \
--fine_tune_batch_norm=False \
--dataset="ade20k" \
--initialize_last_layer=False \
--last_layers_contain_logits_only=True \
--tf_initial_checkpoint=${PATH_TO_INITIAL_CHECKPOINT} \
--train_logdir=${PATH_TO_TRAIN_DIR}\
--dataset_dir=${PATH_TO_DATASET}
```
where ${PATH\_TO\_INITIAL\_CHECKPOINT} is the path to the initial checkpoint.
For example, if you are using the deeplabv3\_pascal\_train\_aug checkpoint, you
will set this to `/path/to/deeplabv3_pascal_train_aug/model.ckpt`.
${PATH\_TO\_TRAIN\_DIR} is the directory in which training checkpoints and
events will be written to (it is recommended to set it to the
`train_on_train_set/train` above), and ${PATH\_TO\_DATASET} is the directory in
......@@ -83,24 +77,22 @@ which the ADE20K dataset resides (the `tfrecord` above)
**Note that for train.py:**
1. In order to fine tune the BN layers, one needs to use large batch size (> 12),
and set fine_tune_batch_norm = True. Here, we simply use small batch size
during training for the purpose of demonstration. If the users have limited
GPU memory at hand, please fine-tune from our provided checkpoints whose
batch norm parameters have been trained, and use smaller learning rate with
fine_tune_batch_norm = False.
1. In order to fine tune the BN layers, one needs to use large batch size (>
12), and set fine_tune_batch_norm = True. Here, we simply use small batch
size during training for the purpose of demonstration. If the users have
limited GPU memory at hand, please fine-tune from our provided checkpoints
whose batch norm parameters have been trained, and use smaller learning rate
with fine_tune_batch_norm = False.
2. Users should fine-tune the `min_resize_value` and `max_resize_value` to get
better results. Note that `resize_factor` has to be equal to `output_stride`.
2. The users should change atrous_rates from [6, 12, 18] to [12, 24, 36] if
3. The users should change atrous_rates from [6, 12, 18] to [12, 24, 36] if
setting output_stride=8 (see the flag sketch after this list).
3. The users could skip the flag, `decoder_output_stride`, if they do not want
4. The users could skip the flag, `decoder_output_stride`, if they do not want
to use the decoder structure.
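As a rule of thumb, atrous rates scale inversely with `output_stride`; a
sketch of the corresponding flags (rate values taken from note 3 above, used
with the train.py command shown earlier):

```
# output_stride=16 (default):
--output_stride=16 --atrous_rates=6 --atrous_rates=12 --atrous_rates=18
# output_stride=8 (rates double when the stride halves):
--output_stride=8 --atrous_rates=12 --atrous_rates=24 --atrous_rates=36
```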
Currently there is no fine-tuned checkpoint for the ADE20K dataset.
## Running Tensorboard
Progress for training and evaluation jobs can be inspected using Tensorboard. If
......
......@@ -60,6 +60,12 @@ sh local_test_mobilenetv2.sh
First, make sure you can reproduce the results with our provided settings.
After that, start making changes one at a time to help debug.
___
Q8: What value of `eval_crop_size` should I use?
A: Our model uses whole-image inference, meaning that we need to set `eval_crop_size` equal to `output_stride` * k + 1, where k is an integer chosen so that the resulting
`eval_crop_size` is slightly larger than the largest image dimension in the dataset. For example, we use `eval_crop_size` = 513x513 for the PASCAL dataset, whose largest
image dimension is 512. Similarly, we set `eval_crop_size` = 1025x2049 for Cityscapes images, whose dimensions are all equal to 1024x2048.
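A minimal sketch of this rule (the helper name is hypothetical, not part of the codebase):

```python
def valid_eval_crop_size(largest_dim, output_stride=16):
  """Returns the smallest output_stride * k + 1 covering largest_dim."""
  k = -(-largest_dim // output_stride)  # ceil(largest_dim / output_stride)
  return output_stride * k + 1

assert valid_eval_crop_size(512) == 513    # PASCAL VOC 2012
assert valid_eval_crop_size(1024) == 1025  # Cityscapes height
assert valid_eval_crop_size(2048) == 2049  # Cityscapes width
```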
___
## References
......
# TensorFlow DeepLab Model Zoo
We provide deeplab models pretrained on PASCAL VOC 2012 and Cityscapes datasets
for reproducing our results, as well as some checkpoints that are only
pretrained on ImageNet for training your own models.
We provide deeplab models pretrained on several datasets, including (1) PASCAL VOC
2012, (2) Cityscapes, and (3) ADE20K for reproducing our results, as well as
some checkpoints that are only pretrained on ImageNet for training your own
models.
## DeepLab models trained on PASCAL VOC 2012
......@@ -69,6 +70,22 @@ Checkpoint name
[mobilenetv2_coco_cityscapes_trainfine](http://download.tensorflow.org/models/deeplabv3_mnv2_cityscapes_train_2018_02_05.tar.gz) | 16 <br> 8 | [1.0] <br> [0.75:0.25:1.25] | No <br> Yes | 21.27B <br> 433.24B | 0.8 <br> 51.12 | 70.71% (val) <br> 73.57% (val) | 23MB
[xception_cityscapes_trainfine](http://download.tensorflow.org/models/deeplabv3_cityscapes_train_2018_02_06.tar.gz) | 16 <br> 8 | [1.0] <br> [0.75:0.25:1.25] | No <br> Yes | 418.64B <br> 8677.92B | 5.0 <br> 422.8 | 78.79% (val) <br> 80.42% (val) | 439MB
## DeepLab models trained on ADE20K
### Model details
We provide some checkpoints that have been pretrained on the ADE20K training
set. Note that the backbone has only been pretrained on ImageNet, following
the dataset rules.
Checkpoint name | Network backbone | Pretrained dataset | ASPP | Decoder
------------------------------------- | :--------------: | :-------------------------------------: | :----------------------------------------------: | :-----:
xception_ade20k_train | Xception_65 | ImageNet <br> ADE20K training set | [6, 12, 18] for OS=16 <br> [12, 24, 36] for OS=8 | OS = 4
Checkpoint name | Eval OS | Eval scales | Left-right Flip | mIOU | Pixel-wise Accuracy | File Size
----------------------------------------------------------------------------------------------------------------- | :-------: | :-------------------------: | :-------------: | :----------: | :-----------------: | :-------:
[xception_ade20k_train](http://download.tensorflow.org/models/deeplabv3_xception_ade20k_train_2018_05_14.tar.gz) | 16 | [0.5:0.25:1.75] | Yes | 43.54% (val) | 81.74% (val) | 439MB
## Checkpoints pretrained on ImageNet
Un-tar'ed directory includes:
......@@ -84,15 +101,24 @@ one could use this for training your own models.
[MobileNet-V2](https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet)
for details.
* xception: We adapt the original Xception model to the task of semantic
segmentation with the following changes: (1) more layers, (2) all max
pooling operations are replaced by strided (atrous) separable convolutions,
and (3) extra batch-norm and ReLU after each 3x3 depthwise convolution are
added.
* xception_{41,65,71}: We adapt the original Xception model to the task of
semantic segmentation with the following changes: (1) more layers, (2) all
max pooling operations are replaced by strided (atrous) separable
convolutions, and (3) extra batch-norm and ReLU after each 3x3 depthwise
convolution are added. We provide three Xception model variants with
different network depths.
* resnet_v1_{50,101}_beta: We modify the original ResNet-v1 models [10],
similar to PSPNet [11], by replacing the first 7x7 convolution with three 3x3
convolutions. See resnet_v1_beta.py for more details.
Model name | File Size
-------------------------------------------------------------------------------------- | :-------:
[xception](http://download.tensorflow.org/models/deeplabv3_xception_2018_01_04.tar.gz) | 447MB
[xception_41](http://download.tensorflow.org/models/xception_41_2018_05_09.tar.gz) | 288MB
[xception_65](http://download.tensorflow.org/models/deeplabv3_xception_2018_01_04.tar.gz) | 447MB
[xception_71](http://download.tensorflow.org/models/xception_71_2018_05_09.tar.gz) | 474MB
[resnet_v1_50_beta](http://download.tensorflow.org/models/resnet_v1_50_2018_05_04.tar.gz) | 274MB
[resnet_v1_101_beta](http://download.tensorflow.org/models/resnet_v1_101_2018_05_04.tar.gz) | 477MB
## References
......@@ -132,3 +158,16 @@ Model name
9. **ImageNet Large Scale Visual Recognition Challenge**<br />
Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, Alexander C. Berg, Li Fei-Fei<br />
[[link]](http://www.image-net.org/). IJCV, 2015.
10. **Deep Residual Learning for Image Recognition**<br />
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun<br />
[[link]](https://arxiv.org/abs/1512.03385). CVPR, 2016.
11. **Pyramid Scene Parsing Network**<br />
Hengshuang Zhao, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, Jiaya Jia<br />
[[link]](https://arxiv.org/abs/1612.01105). In CVPR, 2017.
12. **Scene Parsing through ADE20K Dataset**<br />
Bolei Zhou, Hang Zhao, Xavier Puig, Sanja Fidler, Adela Barriuso, Antonio Torralba<br />
[[link]](http://groups.csail.mit.edu/vision/datasets/ADE20K/). In CVPR,
2017.
......@@ -64,6 +64,10 @@ _CONCAT_PROJECTION_SCOPE = 'concat_projection'
_DECODER_SCOPE = 'decoder'
def get_merged_logits_scope():
return _MERGED_LOGITS_SCOPE
def get_extra_layer_scopes(last_layers_contain_logits_only=False):
"""Gets the scopes for extra layers.
......@@ -358,6 +362,7 @@ def _extract_features(images,
output_stride=model_options.output_stride,
multi_grid=model_options.multi_grid,
model_variant=model_options.model_variant,
depth_multiplier=model_options.depth_multiplier,
weight_decay=weight_decay,
reuse=reuse,
is_training=is_training,
......
......@@ -111,7 +111,7 @@ class DeeplabModelTest(tf.test.TestCase):
for output in outputs_to_num_classes:
scales_to_logits = outputs_to_scales_to_logits[output]
# Expect only one output.
self.assertEquals(len(scales_to_logits), 1)
self.assertEqual(len(scales_to_logits), 1)
for logits in scales_to_logits.values():
self.assertTrue(logits.any())
......
......@@ -68,7 +68,8 @@ flags.DEFINE_integer('save_summaries_secs', 600,
'How often, in seconds, we compute the summaries.')
flags.DEFINE_boolean('save_summaries_images', False,
'Save sample inputs, labels, and semantic predictions as images to summary.')
'Save sample inputs, labels, and semantic predictions as '
'images to summary.')
# Settings for training strategy.
......@@ -184,9 +185,11 @@ def _build_deeplab(inputs_queue, outputs_to_num_classes, ignore_label):
"""
samples = inputs_queue.dequeue()
# add name to input and label nodes so we can add to summary
samples[common.IMAGE] = tf.identity(samples[common.IMAGE], name = common.IMAGE)
samples[common.LABEL] = tf.identity(samples[common.LABEL], name = common.LABEL)
# Add name to input and label nodes so we can add to summary.
samples[common.IMAGE] = tf.identity(
samples[common.IMAGE], name=common.IMAGE)
samples[common.LABEL] = tf.identity(
samples[common.LABEL], name=common.LABEL)
model_options = common.ModelOptions(
outputs_to_num_classes=outputs_to_num_classes,
......@@ -201,11 +204,11 @@ def _build_deeplab(inputs_queue, outputs_to_num_classes, ignore_label):
is_training=True,
fine_tune_batch_norm=FLAGS.fine_tune_batch_norm)
# add name to graph node so we can add to summary
outputs_to_scales_to_logits[common.OUTPUT_TYPE][model._MERGED_LOGITS_SCOPE] = tf.identity(
outputs_to_scales_to_logits[common.OUTPUT_TYPE][model._MERGED_LOGITS_SCOPE],
name = common.OUTPUT_TYPE
)
# Add name to graph node so we can add to summary.
output_type_dict = outputs_to_scales_to_logits[common.OUTPUT_TYPE]
output_type_dict[model.get_merged_logits_scope()] = tf.identity(
output_type_dict[model.get_merged_logits_scope()],
name=common.OUTPUT_TYPE)
for output, num_classes in six.iteritems(outputs_to_num_classes):
train_utils.add_softmax_cross_entropy_loss_for_each_scale(
......@@ -234,7 +237,7 @@ def main(unused_argv):
assert FLAGS.train_batch_size % config.num_clones == 0, (
'Training batch size not divisible by number of clones (GPUs).')
clone_batch_size = int(FLAGS.train_batch_size / config.num_clones)
clone_batch_size = FLAGS.train_batch_size // config.num_clones
# Get dataset-dependent information.
dataset = segmentation_dataset.get_dataset(
......@@ -288,17 +291,25 @@ def main(unused_argv):
if FLAGS.save_summaries_images:
summary_image = graph.get_tensor_by_name(
('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
summaries.add(tf.summary.image('samples/%s' % common.IMAGE, summary_image))
summary_label = tf.cast(graph.get_tensor_by_name(
('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/')),
tf.uint8)
summaries.add(tf.summary.image('samples/%s' % common.LABEL, summary_label))
predictions = tf.cast(tf.expand_dims(tf.argmax(graph.get_tensor_by_name(
('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/')),
3), -1), tf.uint8)
summaries.add(tf.summary.image('samples/%s' % common.OUTPUT_TYPE, predictions))
summaries.add(
tf.summary.image('samples/%s' % common.IMAGE, summary_image))
first_clone_label = graph.get_tensor_by_name(
('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
# Scale up summary image pixel values for better visualization.
pixel_scaling = max(1, 255 // dataset.num_classes)
summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
summaries.add(
tf.summary.image('samples/%s' % common.LABEL, summary_label))
first_clone_output = graph.get_tensor_by_name(
('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)
summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
summaries.add(
tf.summary.image(
'samples/%s' % common.OUTPUT_TYPE, summary_predictions))
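# For intuition (illustrative numbers): with PASCAL VOC's 21 classes,
# pixel_scaling = max(1, 255 // 21) = 12, so class 20 renders as gray value
# 240 in the summary image rather than a nearly invisible 20.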
# Add summaries for losses.
for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
......@@ -325,7 +336,8 @@ def main(unused_argv):
summaries.add(tf.summary.scalar('total_loss', total_loss))
# Modify the gradients for biases and last layer variables.
last_layers = model.get_extra_layer_scopes(FLAGS.last_layers_contain_logits_only)
last_layers = model.get_extra_layer_scopes(
FLAGS.last_layers_contain_logits_only)
grad_mult = train_utils.get_model_gradient_multipliers(
last_layers, FLAGS.last_layer_gradient_multiplier)
if grad_mult:
......
......@@ -17,30 +17,196 @@
Visualizes the semantic segmentation results by the color map
defined by the different datasets. Supported colormaps are:
1. PASCAL VOC semantic segmentation benchmark.
Website: http://host.robots.ox.ac.uk/pascal/VOC/
* ADE20K (http://groups.csail.mit.edu/vision/datasets/ADE20K/).
* Cityscapes dataset (https://www.cityscapes-dataset.com).
* PASCAL VOC 2012 (http://host.robots.ox.ac.uk/pascal/VOC/).
"""
import numpy as np
# Dataset names.
_ADE20K = 'ade20k'
_CITYSCAPES = 'cityscapes'
_PASCAL = 'pascal'
# Max number of entries in the colormap for each dataset.
_DATASET_MAX_ENTRIES = {
_ADE20K: 151,
_CITYSCAPES: 19,
_PASCAL: 256,
}
def create_ade20k_label_colormap():
"""Creates a label colormap used in ADE20K segmentation benchmark.
Returns:
A colormap for visualizing segmentation results.
"""
return np.asarray([
[0, 0, 0],
[120, 120, 120],
[180, 120, 120],
[6, 230, 230],
[80, 50, 50],
[4, 200, 3],
[120, 120, 80],
[140, 140, 140],
[204, 5, 255],
[230, 230, 230],
[4, 250, 7],
[224, 5, 255],
[235, 255, 7],
[150, 5, 61],
[120, 120, 70],
[8, 255, 51],
[255, 6, 82],
[143, 255, 140],
[204, 255, 4],
[255, 51, 7],
[204, 70, 3],
[0, 102, 200],
[61, 230, 250],
[255, 6, 51],
[11, 102, 255],
[255, 7, 71],
[255, 9, 224],
[9, 7, 230],
[220, 220, 220],
[255, 9, 92],
[112, 9, 255],
[8, 255, 214],
[7, 255, 224],
[255, 184, 6],
[10, 255, 71],
[255, 41, 10],
[7, 255, 255],
[224, 255, 8],
[102, 8, 255],
[255, 61, 6],
[255, 194, 7],
[255, 122, 8],
[0, 255, 20],
[255, 8, 41],
[255, 5, 153],
[6, 51, 255],
[235, 12, 255],
[160, 150, 20],
[0, 163, 255],
[140, 140, 140],
[250, 10, 15],
[20, 255, 0],
[31, 255, 0],
[255, 31, 0],
[255, 224, 0],
[153, 255, 0],
[0, 0, 255],
[255, 71, 0],
[0, 235, 255],
[0, 173, 255],
[31, 0, 255],
[11, 200, 200],
[255, 82, 0],
[0, 255, 245],
[0, 61, 255],
[0, 255, 112],
[0, 255, 133],
[255, 0, 0],
[255, 163, 0],
[255, 102, 0],
[194, 255, 0],
[0, 143, 255],
[51, 255, 0],
[0, 82, 255],
[0, 255, 41],
[0, 255, 173],
[10, 0, 255],
[173, 255, 0],
[0, 255, 153],
[255, 92, 0],
[255, 0, 255],
[255, 0, 245],
[255, 0, 102],
[255, 173, 0],
[255, 0, 20],
[255, 184, 184],
[0, 31, 255],
[0, 255, 61],
[0, 71, 255],
[255, 0, 204],
[0, 255, 194],
[0, 255, 82],
[0, 10, 255],
[0, 112, 255],
[51, 0, 255],
[0, 194, 255],
[0, 122, 255],
[0, 255, 163],
[255, 153, 0],
[0, 255, 10],
[255, 112, 0],
[143, 255, 0],
[82, 0, 255],
[163, 255, 0],
[255, 235, 0],
[8, 184, 170],
[133, 0, 255],
[0, 255, 92],
[184, 0, 255],
[255, 0, 31],
[0, 184, 255],
[0, 214, 255],
[255, 0, 112],
[92, 255, 0],
[0, 224, 255],
[112, 224, 255],
[70, 184, 160],
[163, 0, 255],
[153, 0, 255],
[71, 255, 0],
[255, 0, 163],
[255, 204, 0],
[255, 0, 143],
[0, 255, 235],
[133, 255, 0],
[255, 0, 235],
[245, 0, 255],
[255, 0, 122],
[255, 245, 0],
[10, 190, 212],
[214, 255, 0],
[0, 204, 255],
[20, 0, 255],
[255, 255, 0],
[0, 153, 255],
[0, 41, 255],
[0, 255, 204],
[41, 0, 255],
[41, 255, 0],
[173, 0, 255],
[0, 245, 255],
[71, 0, 255],
[122, 0, 255],
[0, 255, 184],
[0, 92, 255],
[184, 255, 0],
[0, 133, 255],
[255, 214, 0],
[25, 194, 194],
[102, 255, 0],
[92, 0, 255],
])
def create_cityscapes_label_colormap():
"""Creates a label colormap used in CITYSCAPES segmentation benchmark.
Returns:
A Colormap for visualizing segmentation results.
A colormap for visualizing segmentation results.
"""
colormap = np.asarray([
return np.asarray([
[128, 64, 128],
[244, 35, 232],
[70, 70, 70],
......@@ -61,17 +227,37 @@ def create_cityscapes_label_colormap():
[0, 0, 230],
[119, 11, 32],
])
def create_pascal_label_colormap():
"""Creates a label colormap used in PASCAL VOC segmentation benchmark.
Returns:
A colormap for visualizing segmentation results.
"""
colormap = np.zeros((_DATASET_MAX_ENTRIES[_PASCAL], 3), dtype=int)
ind = np.arange(_DATASET_MAX_ENTRIES[_PASCAL], dtype=int)
for shift in reversed(range(8)):
for channel in range(3):
colormap[:, channel] |= bit_get(ind, channel) << shift
ind >>= 3
return colormap
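# The bit-shuffle above reproduces the standard PASCAL VOC palette, e.g.
#   colormap[0] = [0, 0, 0]      (background)
#   colormap[1] = [128, 0, 0]    (aeroplane)
#   colormap[2] = [0, 128, 0]    (bicycle)
#   colormap[3] = [128, 128, 0]  (bird)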
def get_pascal_name():
return _PASCAL
def get_ade20k_name():
return _ADE20K
def get_cityscapes_name():
return _CITYSCAPES
def get_pascal_name():
return _PASCAL
def bit_get(val, idx):
"""Gets the bit value.
......@@ -85,23 +271,6 @@ def bit_get(val, idx):
return (val >> idx) & 1
def create_pascal_label_colormap():
"""Creates a label colormap used in PASCAL VOC segmentation benchmark.
Returns:
A Colormap for visualizing segmentation results.
"""
colormap = np.zeros((_DATASET_MAX_ENTRIES[_PASCAL], 3), dtype=int)
ind = np.arange(_DATASET_MAX_ENTRIES[_PASCAL], dtype=int)
for shift in reversed(range(8)):
for channel in range(3):
colormap[:, channel] |= bit_get(ind, channel) << shift
ind >>= 3
return colormap
def create_label_colormap(dataset=_PASCAL):
"""Creates a label colormap for the specified dataset.
......@@ -114,10 +283,12 @@ def create_label_colormap(dataset=_PASCAL):
Raises:
ValueError: If the dataset is not supported.
"""
if dataset == _PASCAL:
return create_pascal_label_colormap()
if dataset == _ADE20K:
return create_ade20k_label_colormap()
elif dataset == _CITYSCAPES:
return create_cityscapes_label_colormap()
elif dataset == _PASCAL:
return create_pascal_label_colormap()
else:
raise ValueError('Unsupported dataset.')
......@@ -132,7 +303,7 @@ def label_to_color_image(label, dataset=_PASCAL):
Returns:
result: A 2D array with floating type. The element of the array
is the color indexed by the corresponding element in the input label
to the PASCAL color map.
to the dataset color map.
Raises:
ValueError: If label is not of rank 2 or its value is larger than color
......
......@@ -70,6 +70,22 @@ class VisualizationUtilTest(tf.test.TestCase):
with self.assertRaises(ValueError):
get_dataset_colormap.create_label_colormap('unsupported_dataset')
def testUnExpectedLabelDimensionForLabelToADE20KColorImage(self):
label = np.array([250])
with self.assertRaises(ValueError):
get_dataset_colormap.label_to_color_image(
label, get_dataset_colormap.get_ade20k_name())
def testFirstColorInADE20KColorMap(self):
label = np.array([[1, 3], [10, 20]])
expected_result = np.array([
[[120, 120, 120], [6, 230, 230]],
[[4, 250, 7], [204, 70, 3]]
])
colored_label = get_dataset_colormap.label_to_color_image(
label, get_dataset_colormap.get_ade20k_name())
self.assertTrue(np.array_equal(colored_label, expected_result))
if __name__ == '__main__':
tf.test.main()