Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
05ccaf88
Unverified
Commit
05ccaf88
authored
Mar 08, 2018
by
Lukasz Kaiser
Committed by
GitHub
Mar 08, 2018
Browse files
Merge pull request #3521 from YknZhu/master
Add deeplab model in tensorflow models
parents
6571d16d
1e9b07d8
Changes
46
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
965 additions
and
0 deletions
+965
-0
research/deeplab/utils/get_dataset_colormap.py
research/deeplab/utils/get_dataset_colormap.py
+148
-0
research/deeplab/utils/get_dataset_colormap_test.py
research/deeplab/utils/get_dataset_colormap_test.py
+75
-0
research/deeplab/utils/input_generator.py
research/deeplab/utils/input_generator.py
+168
-0
research/deeplab/utils/save_annotation.py
research/deeplab/utils/save_annotation.py
+52
-0
research/deeplab/utils/train_utils.py
research/deeplab/utils/train_utils.py
+202
-0
research/deeplab/vis.py
research/deeplab/vis.py
+320
-0
No files found.
research/deeplab/utils/get_dataset_colormap.py
0 → 100644
View file @
05ccaf88
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Visualizes the segmentation results via specified color map.
Visualizes the semantic segmentation results by the color map
defined by the different datasets. Supported colormaps are:
1. PASCAL VOC semantic segmentation benchmark.
Website: http://host.robots.ox.ac.uk/pascal/VOC/
"""
import
numpy
as
np
# Dataset names.
_CITYSCAPES = 'cityscapes'
_PASCAL = 'pascal'

# Max number of entries in the colormap for each dataset.
# Cityscapes: 19 entries, matching the 19 rows returned by
# create_cityscapes_label_colormap below.
# PASCAL: 256 entries, the full 8-bit label range produced by
# create_pascal_label_colormap below.
_DATASET_MAX_ENTRIES = {
    _CITYSCAPES: 19,
    _PASCAL: 256,
}
def create_cityscapes_label_colormap():
  """Creates a label colormap used in CITYSCAPES segmentation benchmark.

  Returns:
    A Colormap for visualizing segmentation results.
  """
  # One RGB triple per label entry; row index corresponds to the label id.
  cityscapes_colors = [
      (128, 64, 128),
      (244, 35, 232),
      (70, 70, 70),
      (102, 102, 156),
      (190, 153, 153),
      (153, 153, 153),
      (250, 170, 30),
      (220, 220, 0),
      (107, 142, 35),
      (152, 251, 152),
      (70, 130, 180),
      (220, 20, 60),
      (255, 0, 0),
      (0, 0, 142),
      (0, 0, 70),
      (0, 60, 100),
      (0, 80, 100),
      (0, 0, 230),
      (119, 11, 32),
  ]
  return np.asarray(cityscapes_colors)
def get_pascal_name():
  """Returns the string key identifying the PASCAL colormap."""
  return _PASCAL
def get_cityscapes_name():
  """Returns the string key identifying the Cityscapes colormap."""
  return _CITYSCAPES
def bit_get(val, idx):
  """Extracts a single bit from an integer or integer array.

  Args:
    val: Input value, int or numpy int array.
    idx: Which bit of the input val.

  Returns:
    The "idx"-th bit of input val (0 or 1, elementwise for arrays).
  """
  shifted = val >> idx
  return shifted & 1
def create_pascal_label_colormap():
  """Creates a label colormap used in PASCAL VOC segmentation benchmark.

  Builds the colormap by distributing each label index's bits across the
  three color channels, three bits per round, from the most significant
  output bit down.

  Returns:
    A Colormap for visualizing segmentation results.
  """
  num_entries = _DATASET_MAX_ENTRIES[_PASCAL]
  colormap = np.zeros((num_entries, 3), dtype=int)
  indices = np.arange(num_entries, dtype=int)

  for shift in range(7, -1, -1):
    for channel in range(3):
      colormap[:, channel] |= bit_get(indices, channel) << shift
    indices >>= 3

  return colormap
def create_label_colormap(dataset=_PASCAL):
  """Creates a label colormap for the specified dataset.

  Args:
    dataset: The colormap used in the dataset.

  Returns:
    A numpy array of the dataset colormap.

  Raises:
    ValueError: If the dataset is not supported.
  """
  colormap_builders = {
      _PASCAL: create_pascal_label_colormap,
      _CITYSCAPES: create_cityscapes_label_colormap,
  }
  if dataset not in colormap_builders:
    raise ValueError('Unsupported dataset.')
  return colormap_builders[dataset]()
def label_to_color_image(label, dataset=_PASCAL):
  """Adds color defined by the dataset colormap to the label.

  Args:
    label: A 2D array with integer type, storing the segmentation label.
    dataset: The colormap used in the dataset.

  Returns:
    result: A 2D array with floating type. The element of the array
      is the color indexed by the corresponding element in the input label
      to the PASCAL color map.

  Raises:
    ValueError: If label is not of rank 2 or its value is larger than color
      map maximum entry.
  """
  if label.ndim != 2:
    raise ValueError('Expect 2-D input label')

  max_entry = _DATASET_MAX_ENTRIES[dataset]
  if np.max(label) >= max_entry:
    raise ValueError('label value too large.')

  # Fancy-index the colormap with the label grid: each label id becomes
  # its RGB row, yielding an (H, W, 3) image.
  return create_label_colormap(dataset)[label]
research/deeplab/utils/get_dataset_colormap_test.py
0 → 100644
View file @
05ccaf88
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for get_dataset_colormap.py."""
import
numpy
as
np
import
tensorflow
as
tf
from
deeplab.utils
import
get_dataset_colormap
class VisualizationUtilTest(tf.test.TestCase):
  """Unit tests for the colormap utilities in get_dataset_colormap."""

  def testBitGet(self):
    """Test that the returned bit value is correct."""
    # 9 == 0b1001, so bits 0 and 3 are set.
    self.assertEqual(1, get_dataset_colormap.bit_get(9, 0))
    self.assertEqual(0, get_dataset_colormap.bit_get(9, 1))
    self.assertEqual(0, get_dataset_colormap.bit_get(9, 2))
    self.assertEqual(1, get_dataset_colormap.bit_get(9, 3))

  def testPASCALLabelColorMapValue(self):
    """Test the generated color map values."""
    colormap = get_dataset_colormap.create_pascal_label_colormap()

    # Only test a few sampled entries in the color map.
    self.assertTrue(np.array_equal([128., 0., 128.], colormap[5, :]))
    self.assertTrue(np.array_equal([128., 192., 128.], colormap[23, :]))
    self.assertTrue(np.array_equal([128., 0., 192.], colormap[37, :]))
    self.assertTrue(np.array_equal([224., 192., 192.], colormap[127, :]))
    self.assertTrue(np.array_equal([192., 160., 192.], colormap[175, :]))

  def testLabelToPASCALColorImage(self):
    """Test the value of the converted label value."""
    label = np.array([[0, 16, 16], [52, 7, 52]])
    # Each label id maps to the corresponding row of the PASCAL colormap.
    expected_result = np.array([
        [[0, 0, 0], [0, 64, 0], [0, 64, 0]],
        [[0, 64, 192], [128, 128, 128], [0, 64, 192]]
    ])
    colored_label = get_dataset_colormap.label_to_color_image(
        label, get_dataset_colormap.get_pascal_name())
    self.assertTrue(np.array_equal(expected_result, colored_label))

  def testUnExpectedLabelValueForLabelToPASCALColorImage(self):
    """Raise ValueError when input value exceeds range."""
    # 300 >= 256, the PASCAL colormap size.
    label = np.array([[120], [300]])
    with self.assertRaises(ValueError):
      get_dataset_colormap.label_to_color_image(
          label, get_dataset_colormap.get_pascal_name())

  def testUnExpectedLabelDimensionForLabelToPASCALColorImage(self):
    """Raise ValueError if input dimension is not correct."""
    # A rank-1 label is rejected; only 2-D labels are accepted.
    label = np.array([120])
    with self.assertRaises(ValueError):
      get_dataset_colormap.label_to_color_image(
          label, get_dataset_colormap.get_pascal_name())

  def testGetColormapForUnsupportedDataset(self):
    """Raise ValueError for a dataset name with no registered colormap."""
    with self.assertRaises(ValueError):
      get_dataset_colormap.create_label_colormap('unsupported_dataset')
if __name__ == '__main__':
  # Discover and run all test methods via the TensorFlow test runner.
  tf.test.main()
research/deeplab/utils/input_generator.py
0 → 100644
View file @
05ccaf88
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for providing semantic segmentation data."""
import
tensorflow
as
tf
from
deeplab
import
common
from
deeplab
import
input_preprocess
# TF-Slim (TF 1.x contrib) and its queue-based dataset provider.
slim = tf.contrib.slim

dataset_data_provider = slim.dataset_data_provider
def _get_data(data_provider, dataset_split):
  """Gets data from data provider.

  Args:
    data_provider: An object of slim.data_provider.
    dataset_split: Dataset split.

  Returns:
    image: Image Tensor.
    label: Label Tensor storing segmentation annotations.
    image_name: Image name.
    height: Image height.
    width: Image width.

  Raises:
    ValueError: Failed to find label.
  """
  available_items = data_provider.list_items()
  if common.LABELS_CLASS not in available_items:
    raise ValueError('Failed to find labels.')

  image, height, width = data_provider.get(
      [common.IMAGE, common.HEIGHT, common.WIDTH])

  # Some datasets do not contain image_name; fall back to an empty name.
  if common.IMAGE_NAME in available_items:
    [image_name] = data_provider.get([common.IMAGE_NAME])
  else:
    image_name = tf.constant('')

  # The test split carries no annotations, so label stays None there.
  label = None
  if dataset_split != common.TEST_SET:
    [label] = data_provider.get([common.LABELS_CLASS])

  return image, label, image_name, height, width
def get(dataset,
        crop_size,
        batch_size,
        min_resize_value=None,
        max_resize_value=None,
        resize_factor=None,
        min_scale_factor=1.,
        max_scale_factor=1.,
        scale_factor_step_size=0,
        num_readers=1,
        num_threads=1,
        dataset_split=None,
        is_training=True,
        model_variant=None):
  """Gets the dataset split for semantic segmentation.

  This functions gets the dataset split for semantic segmentation. In
  particular, it is a wrapper of (1) dataset_data_provider which returns the raw
  dataset split, (2) input_preprcess which preprocess the raw data, and (3) the
  Tensorflow operation of batching the preprocessed data. Then, the output could
  be directly used by training, evaluation or visualization.

  Args:
    dataset: An instance of slim Dataset.
    crop_size: Image crop size [height, width].
    batch_size: Batch size.
    min_resize_value: Desired size of the smaller image side.
    max_resize_value: Maximum allowed size of the larger image side.
    resize_factor: Resized dimensions are multiple of factor plus one.
    min_scale_factor: Minimum scale factor value.
    max_scale_factor: Maximum scale factor value.
    scale_factor_step_size: The step size from min scale factor to max scale
      factor. The input is randomly scaled based on the value of
      (min_scale_factor, max_scale_factor, scale_factor_step_size).
    num_readers: Number of readers for data provider.
    num_threads: Number of threads for batching data.
    dataset_split: Dataset split.
    is_training: Is training or not.
    model_variant: Model variant (string) for choosing how to mean-subtract the
      images. See feature_extractor.network_map for supported model variants.

  Returns:
    A dictionary of batched Tensors for semantic segmentation.

  Raises:
    ValueError: dataset_split is None, failed to find labels, or label shape
      is not valid.
  """
  if dataset_split is None:
    raise ValueError('Unknown dataset split.')
  if model_variant is None:
    tf.logging.warning('Please specify a model_variant. See '
                       'feature_extractor.network_map for supported model '
                       'variants.')

  # num_epochs=None loops over the data forever during training; a single
  # epoch otherwise. Shuffling only makes sense for training.
  data_provider = dataset_data_provider.DatasetDataProvider(
      dataset,
      num_readers=num_readers,
      num_epochs=None if is_training else 1,
      shuffle=is_training)
  image, label, image_name, height, width = _get_data(data_provider,
                                                      dataset_split)
  if label is not None:
    # Normalize the label to rank 3 with a trailing channel of size 1.
    if label.shape.ndims == 2:
      label = tf.expand_dims(label, 2)
    elif label.shape.ndims == 3 and label.shape.dims[2] == 1:
      pass
    else:
      raise ValueError('Input label shape must be [height, width], or '
                       '[height, width, 1].')

    label.set_shape([None, None, 1])
  original_image, image, label = input_preprocess.preprocess_image_and_label(
      image,
      label,
      crop_height=crop_size[0],
      crop_width=crop_size[1],
      min_resize_value=min_resize_value,
      max_resize_value=max_resize_value,
      resize_factor=resize_factor,
      min_scale_factor=min_scale_factor,
      max_scale_factor=max_scale_factor,
      scale_factor_step_size=scale_factor_step_size,
      ignore_label=dataset.ignore_label,
      is_training=is_training,
      model_variant=model_variant)
  sample = {
      common.IMAGE: image,
      common.IMAGE_NAME: image_name,
      common.HEIGHT: height,
      common.WIDTH: width
  }
  if label is not None:
    sample[common.LABEL] = label

  if not is_training:
    # Original image is only used during visualization.
    # NOTE(review): the trailing comma stores a 1-tuple (original_image,)
    # rather than the tensor itself — looks unintended; confirm against
    # upstream before relying on the stored value's structure.
    sample[common.ORIGINAL_IMAGE] = original_image,
    # Single batching thread when not training — presumably to keep the
    # output order deterministic; confirm.
    num_threads = 1

  return tf.train.batch(
      sample,
      batch_size=batch_size,
      num_threads=num_threads,
      capacity=32 * batch_size,
      allow_smaller_final_batch=not is_training,
      dynamic_pad=True)
research/deeplab/utils/save_annotation.py
0 → 100644
View file @
05ccaf88
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Saves an annotation as one png image.
This script saves an annotation as one png image, and has the option to add
colormap to the png image for better visualization.
"""
import
numpy
as
np
import
PIL.Image
as
img
import
tensorflow
as
tf
from
deeplab.utils
import
get_dataset_colormap
def save_annotation(label, save_dir, filename, add_colormap=True,
                    colormap_type=get_dataset_colormap.get_pascal_name()):
  """Saves the given label to image on disk.

  Args:
    label: The numpy array to be saved. The data will be converted
      to uint8 and saved as png image.
    save_dir: The directory to which the results will be saved.
    filename: The image filename.
    add_colormap: Add color map to the label or not.
    colormap_type: Colormap type for visualization.
  """
  if add_colormap:
    # Map label ids to RGB colors for visualization.
    output = get_dataset_colormap.label_to_color_image(label, colormap_type)
  else:
    output = label

  pil_image = img.fromarray(output.astype(dtype=np.uint8))
  output_path = '%s/%s.png' % (save_dir, filename)
  with tf.gfile.Open(output_path, mode='w') as f:
    pil_image.save(f, 'PNG')
research/deeplab/utils/train_utils.py
0 → 100644
View file @
05ccaf88
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for training."""
import
tensorflow
as
tf
# TF-Slim alias (TF 1.x contrib).
slim = tf.contrib.slim
def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
                                                  labels,
                                                  num_classes,
                                                  ignore_label,
                                                  loss_weight=1.0,
                                                  upsample_logits=True,
                                                  scope=None):
  """Adds softmax cross entropy loss for logits of each scale.

  Args:
    scales_to_logits: A map from logits names for different scales to logits.
      The logits have shape [batch, logits_height, logits_width, num_classes].
    labels: Groundtruth labels with shape [batch, image_height, image_width, 1].
    num_classes: Integer, number of target classes.
    ignore_label: Integer, label to ignore.
    loss_weight: Float, loss weight.
    upsample_logits: Boolean, upsample logits or not.
    scope: String, the scope for the loss.

  Raises:
    ValueError: Label or logits is None.
  """
  if labels is None:
    raise ValueError('No label for softmax cross entropy loss.')

  # Fix: dict.iteritems() exists only on Python 2 and raises AttributeError
  # on Python 3; dict.items() behaves identically here on both versions.
  for scale, logits in scales_to_logits.items():
    loss_scope = None
    if scope:
      loss_scope = '%s_%s' % (scope, scale)

    if upsample_logits:
      # Label is not downsampled, and instead we upsample logits.
      logits = tf.image.resize_bilinear(
          logits, tf.shape(labels)[1:3], align_corners=True)
      scaled_labels = labels
    else:
      # Label is downsampled to the same size as logits.
      scaled_labels = tf.image.resize_nearest_neighbor(
          labels, tf.shape(logits)[1:3], align_corners=True)

    scaled_labels = tf.reshape(scaled_labels, shape=[-1])
    # Zero out the loss for pixels carrying the ignore_label; all other
    # pixels are weighted by loss_weight.
    not_ignore_mask = tf.to_float(tf.not_equal(scaled_labels,
                                               ignore_label)) * loss_weight
    one_hot_labels = slim.one_hot_encoding(
        scaled_labels, num_classes, on_value=1.0, off_value=0.0)
    # Registers the loss in the tf.losses collection (no return value used).
    tf.losses.softmax_cross_entropy(
        one_hot_labels,
        tf.reshape(logits, shape=[-1, num_classes]),
        weights=not_ignore_mask,
        scope=loss_scope)
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=False):
  """Gets the function initializing model variables from a checkpoint.

  Args:
    train_logdir: Log directory for training.
    tf_initial_checkpoint: TensorFlow checkpoint for initialization.
    initialize_last_layer: Initialize last layer or not.
    last_layers: Last layers of the model.
    ignore_missing_vars: Ignore missing variables in the checkpoint.

  Returns:
    Initialization function, or None when no initialization should happen.
  """
  # No initial checkpoint provided: nothing to restore.
  if tf_initial_checkpoint is None:
    tf.logging.info('Not initializing the model from a checkpoint.')
    return None

  # A checkpoint already exists in the training directory: training is
  # resuming, so skip re-initialization from the initial checkpoint.
  if tf.train.latest_checkpoint(train_logdir):
    tf.logging.info('Ignoring initialization; other checkpoint exists')
    return None

  tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)

  # Variables that will not be restored.
  excluded = ['global_step']
  if not initialize_last_layer:
    excluded.extend(last_layers)

  restore_vars = slim.get_variables_to_restore(exclude=excluded)

  return slim.assign_from_checkpoint_fn(
      tf_initial_checkpoint,
      restore_vars,
      ignore_missing_vars=ignore_missing_vars)
def get_model_gradient_multipliers(last_layers, last_layer_gradient_multiplier):
  """Gets the gradient multipliers.

  The gradient multipliers will adjust the learning rates for model
  variables. For the task of semantic segmentation, the models are
  usually fine-tuned from the models trained on the task of image
  classification. To fine-tune the models, we usually set larger (e.g.,
  10 times larger) learning rate for the parameters of last layer.

  Args:
    last_layers: Scopes of last layers.
    last_layer_gradient_multiplier: The gradient multiplier for last layers.

  Returns:
    The gradient multiplier map with variables as key, and multipliers as
    value.
  """
  multipliers = {}

  for variable in slim.get_model_variables():
    name = variable.op.name
    is_bias = 'biases' in name

    # Double the learning rate for biases.
    if is_bias:
      multipliers[name] = 2.

    # Use a larger learning rate for last-layer variables; last-layer
    # biases get the doubled multiplier.
    for layer in last_layers:
      if layer in name:
        if is_bias:
          multipliers[name] = 2 * last_layer_gradient_multiplier
        else:
          multipliers[name] = last_layer_gradient_multiplier
        break

  return multipliers
def get_model_learning_rate(
    learning_policy, base_learning_rate, learning_rate_decay_step,
    learning_rate_decay_factor, training_number_of_steps, learning_power,
    slow_start_step, slow_start_learning_rate):
  """Gets model's learning rate.

  Computes the model's learning rate for different learning policy.
  Right now, only "step" and "poly" are supported.
  (1) The learning policy for "step" is computed as follows:
    current_learning_rate = base_learning_rate *
      learning_rate_decay_factor ^ (global_step / learning_rate_decay_step)
  See tf.train.exponential_decay for details.
  (2) The learning policy for "poly" is computed as follows:
    current_learning_rate = base_learning_rate *
      (1 - global_step / training_number_of_steps) ^ learning_power

  Args:
    learning_policy: Learning rate policy for training.
    base_learning_rate: The base learning rate for model training.
    learning_rate_decay_step: Decay the base learning rate at a fixed step.
    learning_rate_decay_factor: The rate to decay the base learning rate.
    training_number_of_steps: Number of steps for training.
    learning_power: Power used for 'poly' learning policy.
    slow_start_step: Training model with small learning rate for the first
      few steps.
    slow_start_learning_rate: The learning rate employed during slow start.

  Returns:
    Learning rate for the specified learning policy.

  Raises:
    ValueError: If learning policy is not recognized.
  """
  global_step = tf.train.get_or_create_global_step()

  if learning_policy == 'step':
    decayed_rate = tf.train.exponential_decay(
        base_learning_rate,
        global_step,
        learning_rate_decay_step,
        learning_rate_decay_factor,
        staircase=True)
  elif learning_policy == 'poly':
    decayed_rate = tf.train.polynomial_decay(
        base_learning_rate,
        global_step,
        training_number_of_steps,
        end_learning_rate=0,
        power=learning_power)
  else:
    raise ValueError('Unknown learning policy.')

  # Employ small learning rate at the first few steps for warm start.
  return tf.where(global_step < slow_start_step,
                  slow_start_learning_rate,
                  decayed_rate)
research/deeplab/vis.py
0 → 100644
View file @
05ccaf88
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Segmentation results visualization on a given set of images.
See model.py for more details and usage.
"""
import
math
import
os.path
import
time
import
numpy
as
np
import
tensorflow
as
tf
from
deeplab
import
common
from
deeplab
import
model
from
deeplab.datasets
import
segmentation_dataset
from
deeplab.utils
import
input_generator
from
deeplab.utils
import
save_annotation
# TF-Slim alias and the TF 1.x flags module.
slim = tf.contrib.slim

flags = tf.app.flags

FLAGS = flags.FLAGS

flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')

# Settings for log directories.

flags.DEFINE_string('vis_logdir', None, 'Where to write the event logs.')

flags.DEFINE_string('checkpoint_dir', None, 'Directory of model checkpoints.')

# Settings for visualizing the model.

flags.DEFINE_integer('vis_batch_size', 1,
                     'The number of images in each batch during evaluation.')

flags.DEFINE_multi_integer('vis_crop_size', [513, 513],
                           'Crop size [height, width] for visualization.')

flags.DEFINE_integer('eval_interval_secs', 60 * 5,
                     'How often (in seconds) to run evaluation.')

# For `xception_65`, use atrous_rates = [12, 24, 36] if output_stride = 8, or
# rates = [6, 12, 18] if output_stride = 16. Note one could use different
# atrous_rates/output_stride during training/evaluation.
flags.DEFINE_multi_integer('atrous_rates', None,
                           'Atrous rates for atrous spatial pyramid pooling.')

flags.DEFINE_integer('output_stride', 16,
                     'The ratio of input to output spatial resolution.')

# Change to [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] for multi-scale test.
flags.DEFINE_multi_float('eval_scales', [1.0],
                         'The scales to resize images for evaluation.')

# Change to True for adding flipped images during test.
flags.DEFINE_bool('add_flipped_images', False,
                  'Add flipped images for evaluation or not.')

# Dataset settings.

flags.DEFINE_string('dataset', 'pascal_voc_seg',
                    'Name of the segmentation dataset.')

flags.DEFINE_string('vis_split', 'val',
                    'Which split of the dataset used for visualizing results')

flags.DEFINE_string('dataset_dir', None, 'Where the dataset reside.')

flags.DEFINE_enum('colormap_type', 'pascal', ['pascal', 'cityscapes'],
                  'Visualization colormap type.')

flags.DEFINE_boolean('also_save_raw_predictions', False,
                     'Also save raw predictions.')

flags.DEFINE_integer('max_number_of_iterations', 0,
                     'Maximum number of visualization iterations. Will loop '
                     'indefinitely upon nonpositive values.')

# The folder where semantic segmentation predictions are saved.
_SEMANTIC_PREDICTION_SAVE_FOLDER = 'segmentation_results'

# The folder where raw semantic segmentation predictions are saved.
_RAW_SEMANTIC_PREDICTION_SAVE_FOLDER = 'raw_segmentation_results'

# The filename format (zero-padded index) used to save each input image.
_IMAGE_FORMAT = '%06d_image'

# The filename format (zero-padded index) used to save each prediction.
_PREDICTION_FORMAT = '%06d_prediction'

# To evaluate Cityscapes results on the evaluation server, the labels used
# during training should be mapped to the labels for evaluation.
# Index = train id, value = eval id.
_CITYSCAPES_TRAIN_ID_TO_EVAL_ID = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23,
                                   24, 25, 26, 27, 28, 31, 32, 33]
def
_convert_train_id_to_eval_id
(
prediction
,
train_id_to_eval_id
):
"""Converts the predicted label for evaluation.
There are cases where the training labels are not equal to the evaluation
labels. This function is used to perform the conversion so that we could
evaluate the results on the evaluation server.
Args:
prediction: Semantic segmentation prediction.
train_id_to_eval_id: A list mapping from train id to evaluation id.
Returns:
Semantic segmentation prediction whose labels have been changed.
"""
converted_prediction
=
prediction
.
copy
()
for
train_id
,
eval_id
in
enumerate
(
train_id_to_eval_id
):
converted_prediction
[
prediction
==
train_id
]
=
eval_id
return
converted_prediction
def _process_batch(sess, original_images, semantic_predictions, image_names,
                   image_heights, image_widths, image_id_offset, save_dir,
                   raw_save_dir, train_id_to_eval_id=None):
  """Evaluates one single batch qualitatively.

  Args:
    sess: TensorFlow session.
    original_images: One batch of original images.
    semantic_predictions: One batch of semantic segmentation predictions.
    image_names: Image names.
    image_heights: Image heights.
    image_widths: Image widths.
    image_id_offset: Image id offset for indexing images.
    save_dir: The directory where the predictions will be saved.
    raw_save_dir: The directory where the raw predictions will be saved.
    train_id_to_eval_id: A list mapping from train id to eval id.
  """
  # Run the graph once to materialize this batch's tensors as numpy arrays;
  # the parameter names are rebound from Tensors to their evaluated values.
  (original_images,
   semantic_predictions,
   image_names,
   image_heights,
   image_widths) = sess.run([original_images, semantic_predictions,
                             image_names, image_heights, image_widths])

  num_image = semantic_predictions.shape[0]
  for i in range(num_image):
    image_height = np.squeeze(image_heights[i])
    image_width = np.squeeze(image_widths[i])
    original_image = np.squeeze(original_images[i])
    semantic_prediction = np.squeeze(semantic_predictions[i])
    # Crop away any padding added during preprocessing so the saved
    # prediction matches the original image extent.
    crop_semantic_prediction = semantic_prediction[:image_height, :image_width]

    # Save image (no colormap: raw pixels).
    save_annotation.save_annotation(
        original_image, save_dir, _IMAGE_FORMAT % (image_id_offset + i),
        add_colormap=False)

    # Save prediction rendered with the configured colormap.
    save_annotation.save_annotation(
        crop_semantic_prediction, save_dir,
        _PREDICTION_FORMAT % (image_id_offset + i), add_colormap=True,
        colormap_type=FLAGS.colormap_type)

    if FLAGS.also_save_raw_predictions:
      image_filename = image_names[i]

      # Remap train ids to evaluation-server ids (e.g. for Cityscapes).
      if train_id_to_eval_id is not None:
        crop_semantic_prediction = _convert_train_id_to_eval_id(
            crop_semantic_prediction,
            train_id_to_eval_id)
      save_annotation.save_annotation(
          crop_semantic_prediction, raw_save_dir, image_filename,
          add_colormap=False)
def main(unused_argv):
  """Builds the visualization graph and loops over new checkpoints.

  Repeatedly waits for a new checkpoint in FLAGS.checkpoint_dir, restores
  it, runs the prediction graph over the chosen dataset split, and writes
  images and (optionally raw) predictions under FLAGS.vis_logdir.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  # Get dataset-dependent information.
  dataset = segmentation_dataset.get_dataset(
      FLAGS.dataset, FLAGS.vis_split, dataset_dir=FLAGS.dataset_dir)
  train_id_to_eval_id = None
  if dataset.name == segmentation_dataset.get_cityscapes_dataset_name():
    tf.logging.info('Cityscapes requires converting train_id to eval_id.')
    train_id_to_eval_id = _CITYSCAPES_TRAIN_ID_TO_EVAL_ID

  # Prepare output directories for visualization.
  tf.gfile.MakeDirs(FLAGS.vis_logdir)
  save_dir = os.path.join(FLAGS.vis_logdir, _SEMANTIC_PREDICTION_SAVE_FOLDER)
  tf.gfile.MakeDirs(save_dir)
  raw_save_dir = os.path.join(
      FLAGS.vis_logdir, _RAW_SEMANTIC_PREDICTION_SAVE_FOLDER)
  tf.gfile.MakeDirs(raw_save_dir)

  tf.logging.info('Visualizing on %s set', FLAGS.vis_split)

  g = tf.Graph()
  with g.as_default():
    samples = input_generator.get(dataset,
                                  FLAGS.vis_crop_size,
                                  FLAGS.vis_batch_size,
                                  min_resize_value=FLAGS.min_resize_value,
                                  max_resize_value=FLAGS.max_resize_value,
                                  resize_factor=FLAGS.resize_factor,
                                  dataset_split=FLAGS.vis_split,
                                  is_training=False,
                                  model_variant=FLAGS.model_variant)

    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_classes},
        crop_size=FLAGS.vis_crop_size,
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    # Single-scale when eval_scales is exactly [1.0]; multi-scale otherwise.
    if tuple(FLAGS.eval_scales) == (1.0,):
      tf.logging.info('Performing single-scale test.')
      predictions = model.predict_labels(
          samples[common.IMAGE],
          model_options=model_options,
          image_pyramid=FLAGS.image_pyramid)
    else:
      tf.logging.info('Performing multi-scale test.')
      predictions = model.predict_labels_multi_scale(
          samples[common.IMAGE],
          model_options=model_options,
          eval_scales=FLAGS.eval_scales,
          add_flipped_images=FLAGS.add_flipped_images)
    predictions = predictions[common.OUTPUT_TYPE]

    if FLAGS.min_resize_value and FLAGS.max_resize_value:
      # Only support batch_size = 1, since we assume the dimensions of original
      # image after tf.squeeze is [height, width, 3].
      assert FLAGS.vis_batch_size == 1

      # Reverse the resizing and padding operations performed in preprocessing.
      # First, we slice the valid regions (i.e., remove padded region) and then
      # we resize the predictions back.
      original_image = tf.squeeze(samples[common.ORIGINAL_IMAGE])
      original_image_shape = tf.shape(original_image)
      predictions = tf.slice(
          predictions,
          [0, 0, 0],
          [1, original_image_shape[0], original_image_shape[1]])
      resized_shape = tf.to_int32([tf.squeeze(samples[common.HEIGHT]),
                                   tf.squeeze(samples[common.WIDTH])])
      # Nearest-neighbor keeps the label ids intact during resizing.
      predictions = tf.squeeze(
          tf.image.resize_images(tf.expand_dims(predictions, 3),
                                 resized_shape,
                                 method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
                                 align_corners=True), 3)

    tf.train.get_or_create_global_step()
    saver = tf.train.Saver(slim.get_variables_to_restore())
    # Supervisor with summaries/global-step services disabled: used here
    # only for session management and queue runners.
    sv = tf.train.Supervisor(graph=g,
                             logdir=FLAGS.vis_logdir,
                             init_op=tf.global_variables_initializer(),
                             summary_op=None,
                             summary_writer=None,
                             global_step=None,
                             saver=saver)
    num_batches = int(math.ceil(
        dataset.num_samples / float(FLAGS.vis_batch_size)))
    last_checkpoint = None

    # Loop to visualize the results when new checkpoint is created.
    num_iters = 0
    while (FLAGS.max_number_of_iterations <= 0 or
           num_iters < FLAGS.max_number_of_iterations):
      num_iters += 1
      # Blocks until a checkpoint newer than last_checkpoint appears.
      last_checkpoint = slim.evaluation.wait_for_new_checkpoint(
          FLAGS.checkpoint_dir, last_checkpoint)
      start = time.time()
      tf.logging.info(
          'Starting visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                       time.gmtime()))
      tf.logging.info('Visualizing with model %s', last_checkpoint)

      with sv.managed_session(FLAGS.master,
                              start_standard_services=False) as sess:
        sv.start_queue_runners(sess)
        sv.saver.restore(sess, last_checkpoint)

        image_id_offset = 0
        for batch in range(num_batches):
          tf.logging.info('Visualizing batch %d / %d', batch + 1, num_batches)
          _process_batch(sess=sess,
                         original_images=samples[common.ORIGINAL_IMAGE],
                         semantic_predictions=predictions,
                         image_names=samples[common.IMAGE_NAME],
                         image_heights=samples[common.HEIGHT],
                         image_widths=samples[common.WIDTH],
                         image_id_offset=image_id_offset,
                         save_dir=save_dir,
                         raw_save_dir=raw_save_dir,
                         train_id_to_eval_id=train_id_to_eval_id)
          image_id_offset += FLAGS.vis_batch_size

      tf.logging.info(
          'Finished visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                       time.gmtime()))
      # Sleep out the remainder of the evaluation interval before polling
      # for the next checkpoint.
      time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
      if time_to_next_eval > 0:
        time.sleep(time_to_next_eval)
if __name__ == '__main__':
  # These flags have no usable defaults; fail fast if they are missing.
  flags.mark_flag_as_required('checkpoint_dir')
  flags.mark_flag_as_required('vis_logdir')
  flags.mark_flag_as_required('dataset_dir')
  tf.app.run()
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment