Commit 2c962110 authored by huihui-personal, committed by aquariusjay

Open source deeplab changes. Check change log in README.md for detailed info. (#6231)

* Refactor deeplab to use MonitoredTrainingSession

PiperOrigin-RevId: 234237190

* Update export_model.py

* Update nas_cell.py

* Update nas_network.py

* Update train.py

* Update deeplab_demo.ipynb

* Update nas_cell.py
parent a432998c
@@ -66,6 +66,11 @@ A: Our model uses whole-image inference, meaning that we need to set `eval_crop_
 image dimension in the dataset. For example, we have `eval_crop_size` = 513x513 for PASCAL dataset whose largest image dimension is 512. Similarly, we set `eval_crop_size` = 1025x2049 for Cityscapes images whose
 image dimension is all equal to 1024x2048.
 ___
+Q9: Why is multi-GPU training slow?
+
+A: Please try to use more threads to pre-process the inputs. For example, change [num_readers = 4](https://github.com/tensorflow/models/blob/master/research/deeplab/utils/input_generator.py#L71) and [num_threads = 4](https://github.com/tensorflow/models/blob/master/research/deeplab/utils/input_generator.py#L72).
+___
 
 ## References
...
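For reference, a minimal sketch of what that change amounts to when building the input pipeline with `input_generator.get` (the `dataset` object and the crop/batch values here are placeholders for illustration, not values from this commit):

```python
from deeplab.utils import input_generator

# `dataset` is assumed to be an already-configured slim Dataset.
samples = input_generator.get(
    dataset,
    crop_size=[513, 513],
    batch_size=8,
    dataset_split='train',
    is_training=True,
    model_variant='xception_65',
    num_readers=4,  # more parallel readers in the data provider
    num_threads=4)  # more threads for pre-processing and batching
```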
@@ -32,13 +32,13 @@ sudo pip install matplotlib
 ## Add Libraries to PYTHONPATH
 
-When running locally, the tensorflow/models/research/ and slim directories
-should be appended to PYTHONPATH. This can be done by running the following from
+When running locally, the tensorflow/models/research/ directory should be
+appended to PYTHONPATH. This can be done by running the following from
 tensorflow/models/research/:
 
 ```bash
 # From tensorflow/models/research/
-export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
+export PYTHONPATH=$PYTHONPATH:`pwd`
 ```
 
 Note: This command needs to run from every new terminal you start. If you wish
...
@@ -88,13 +88,18 @@ We provide some checkpoints that have been pretrained on ADE20K training set.
 Note that the model has only been pretrained on ImageNet, following the
 dataset rule.
 
-Checkpoint name | Network backbone | Pretrained dataset | ASPP | Decoder
-------------------------------------- | :--------------: | :-------------------------------------: | :----------------------------------------------: | :-----:
-xception65_ade20k_train | Xception_65 | ImageNet <br> ADE20K training set | [6, 12, 18] for OS=16 <br> [12, 24, 36] for OS=8 | OS = 4
+Checkpoint name | Network backbone | Pretrained dataset | ASPP | Decoder | Input size
+------------------------------------- | :--------------: | :-------------------------------------: | :----------------------------------------------: | :-----: | :-----:
+mobilenetv2_ade20k_train | MobileNet-v2 | ImageNet <br> ADE20K training set | N/A | OS = 4 | 257x257
+xception65_ade20k_train | Xception_65 | ImageNet <br> ADE20K training set | [6, 12, 18] for OS=16 <br> [12, 24, 36] for OS=8 | OS = 4 | 513x513
+
+The input dimensions of ADE20K images vary widely. We resize the inputs so that the longest side is 257 for MobileNet-v2 (faster inference) and 513 for Xception_65 (better performance). Note that we also include the decoder module in the MobileNet-v2 checkpoint.
 
 Checkpoint name | Eval OS | Eval scales | Left-right Flip | mIOU | Pixel-wise Accuracy | File Size
 ------------------------------------- | :-------: | :-------------------------: | :-------------: | :-------------------: | :-------------------: | :-------:
+[mobilenetv2_ade20k_train](http://download.tensorflow.org/models/deeplabv3_mnv2_ade20k_train_2018_12_03.tar.gz) | 16 | [1.0] | No | 32.04% (val) | 75.41% (val) | 24.8MB
 [xception65_ade20k_train](http://download.tensorflow.org/models/deeplabv3_xception_ade20k_train_2018_05_29.tar.gz) | 8 | [0.5:0.25:1.75] | Yes | 45.65% (val) | 82.52% (val) | 439MB
 
 ## Checkpoints pretrained on ImageNet
...
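As a rough illustration of that resizing rule (a sketch of the idea, not the repo's exact `resize_to_range` implementation), scaling an image so its longest side matches the checkpoint's input size works like this:

```python
def longest_side_resize(height, width, target_longest_side):
  """Returns (height, width) scaled so the longest side equals the target."""
  scale = float(target_longest_side) / max(height, width)
  return int(round(height * scale)), int(round(width * scale))

# A 600x800 ADE20K image becomes roughly 193x257 for the MobileNet-v2
# checkpoint and 385x513 for the Xception_65 checkpoint.
print(longest_side_resize(600, 800, 257))  # (193, 257)
print(longest_side_resize(600, 800, 513))  # (385, 513)
```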
@@ -82,7 +82,7 @@ def preprocess_image_and_label(image,
     label = tf.cast(label, tf.int32)
 
   # Resize image and label to the desired range.
-  if min_resize_value is not None or max_resize_value is not None:
+  if min_resize_value or max_resize_value:
     [processed_image, label] = (
         preprocess_utils.resize_to_range(
             image=processed_image,
...
@@ -87,6 +87,7 @@ class DeeplabModelTest(tf.test.TestCase):
         add_image_level_feature=True,
         aspp_with_batch_norm=True,
         logits_kernel_size=1,
+        decoder_output_stride=[4],
         model_variant='mobilenet_v2')  # Employ MobileNetv2 for fast test.
 
     g = tf.Graph()
@@ -116,16 +117,16 @@ class DeeplabModelTest(tf.test.TestCase):
     outputs_to_num_classes = {'semantic': 2}
     expected_endpoints = ['merged_logits']
     dense_prediction_cell_config = [
         {'kernel': 3, 'rate': [1, 6], 'op': 'conv', 'input': -1},
         {'kernel': 3, 'rate': [18, 15], 'op': 'conv', 'input': 0},
     ]
     model_options = common.ModelOptions(
         outputs_to_num_classes,
         crop_size,
         output_stride=16)._replace(
             aspp_with_batch_norm=True,
             model_variant='mobilenet_v2',
             dense_prediction_cell_config=dense_prediction_cell_config)
     g = tf.Graph()
     with g.as_default():
       with self.test_session(graph=g):
@@ -137,8 +138,8 @@ class DeeplabModelTest(tf.test.TestCase):
           image_pyramid=[1.0])
       for output in outputs_to_num_classes:
         scales_to_model_results = outputs_to_scales_to_model_results[output]
-        self.assertListEqual(scales_to_model_results.keys(),
-                             expected_endpoints)
+        self.assertListEqual(
+            list(scales_to_model_results), expected_endpoints)
         self.assertEqual(len(scales_to_model_results), 1)
...
This directory contains testing data.

# pascal_voc_seg

This folder contains data specific to the pascal_voc_seg dataset.
val-00000-of-00001.tfrecord contains three randomly generated images with the
format defined in
tensorflow/models/research/deeplab/datasets/build_voc2012_data.py.
@@ -37,7 +37,7 @@ _PASCAL = 'pascal'
 # Max number of entries in the colormap for each dataset.
 _DATASET_MAX_ENTRIES = {
     _ADE20K: 151,
-    _CITYSCAPES: 19,
+    _CITYSCAPES: 256,
     _MAPILLARY_VISTAS: 66,
     _PASCAL: 256,
 }
@@ -210,27 +210,27 @@ def create_cityscapes_label_colormap():
   Returns:
     A colormap for visualizing segmentation results.
   """
-  return np.asarray([
-      [128, 64, 128],
-      [244, 35, 232],
-      [70, 70, 70],
-      [102, 102, 156],
-      [190, 153, 153],
-      [153, 153, 153],
-      [250, 170, 30],
-      [220, 220, 0],
-      [107, 142, 35],
-      [152, 251, 152],
-      [70, 130, 180],
-      [220, 20, 60],
-      [255, 0, 0],
-      [0, 0, 142],
-      [0, 0, 70],
-      [0, 60, 100],
-      [0, 80, 100],
-      [0, 0, 230],
-      [119, 11, 32],
-  ])
+  colormap = np.zeros((256, 3), dtype=np.uint8)
+  colormap[0] = [128, 64, 128]
+  colormap[1] = [244, 35, 232]
+  colormap[2] = [70, 70, 70]
+  colormap[3] = [102, 102, 156]
+  colormap[4] = [190, 153, 153]
+  colormap[5] = [153, 153, 153]
+  colormap[6] = [250, 170, 30]
+  colormap[7] = [220, 220, 0]
+  colormap[8] = [107, 142, 35]
+  colormap[9] = [152, 251, 152]
+  colormap[10] = [70, 130, 180]
+  colormap[11] = [220, 20, 60]
+  colormap[12] = [255, 0, 0]
+  colormap[13] = [0, 0, 142]
+  colormap[14] = [0, 0, 70]
+  colormap[15] = [0, 60, 100]
+  colormap[16] = [0, 80, 100]
+  colormap[17] = [0, 0, 230]
+  colormap[18] = [119, 11, 32]
+  return colormap
 
 
 def create_mapillary_vistas_label_colormap():
@@ -396,10 +396,16 @@ def label_to_color_image(label, dataset=_PASCAL):
       map maximum entry.
   """
   if label.ndim != 2:
-    raise ValueError('Expect 2-D input label')
+    raise ValueError('Expect 2-D input label. Got {}'.format(label.shape))
 
   if np.max(label) >= _DATASET_MAX_ENTRIES[dataset]:
-    raise ValueError('label value too large.')
+    raise ValueError(
+        'label value too large: {} >= {}.'.format(
+            np.max(label), _DATASET_MAX_ENTRIES[dataset]))
 
   colormap = create_label_colormap(dataset)
   return colormap[label]
+
+
+def get_dataset_colormap_max_entries(dataset):
+  return _DATASET_MAX_ENTRIES[dataset]
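A small usage sketch of the functions above (the label values here are arbitrary; only the PASCAL default is taken from the file itself):

```python
import numpy as np
from deeplab.utils import get_dataset_colormap

# A dummy 2-D label map with values well under the PASCAL maximum of 256.
label = np.random.randint(0, 21, size=(65, 65))

# Uses the default PASCAL colormap; returns an RGB image of shape (65, 65, 3).
color_image = get_dataset_colormap.label_to_color_image(label)
print(color_image.shape)

# The new helper exposes the per-dataset colormap size.
print(get_dataset_colormap.get_dataset_colormap_max_entries('pascal'))  # 256
```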
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for providing semantic segmentation data."""

import tensorflow as tf
from deeplab import common
from deeplab import input_preprocess

slim = tf.contrib.slim

dataset_data_provider = slim.dataset_data_provider


def _get_data(data_provider, dataset_split):
  """Gets data from data provider.

  Args:
    data_provider: An object of slim.data_provider.
    dataset_split: Dataset split.

  Returns:
    image: Image Tensor.
    label: Label Tensor storing segmentation annotations.
    image_name: Image name.
    height: Image height.
    width: Image width.

  Raises:
    ValueError: Failed to find label.
  """
  if common.LABELS_CLASS not in data_provider.list_items():
    raise ValueError('Failed to find labels.')

  image, height, width = data_provider.get(
      [common.IMAGE, common.HEIGHT, common.WIDTH])

  # Some datasets do not contain image_name.
  if common.IMAGE_NAME in data_provider.list_items():
    image_name, = data_provider.get([common.IMAGE_NAME])
  else:
    image_name = tf.constant('')

  label = None
  if dataset_split != common.TEST_SET:
    label, = data_provider.get([common.LABELS_CLASS])

  return image, label, image_name, height, width


def get(dataset,
        crop_size,
        batch_size,
        min_resize_value=None,
        max_resize_value=None,
        resize_factor=None,
        min_scale_factor=1.,
        max_scale_factor=1.,
        scale_factor_step_size=0,
        num_readers=1,
        num_threads=1,
        dataset_split=None,
        is_training=True,
        model_variant=None):
  """Gets the dataset split for semantic segmentation.

  This function gets the dataset split for semantic segmentation. In
  particular, it is a wrapper of (1) dataset_data_provider which returns the
  raw dataset split, (2) input_preprocess which preprocesses the raw data, and
  (3) the TensorFlow operation of batching the preprocessed data. Then, the
  output could be directly used by training, evaluation or visualization.

  Args:
    dataset: An instance of slim Dataset.
    crop_size: Image crop size [height, width].
    batch_size: Batch size.
    min_resize_value: Desired size of the smaller image side.
    max_resize_value: Maximum allowed size of the larger image side.
    resize_factor: Resized dimensions are multiple of factor plus one.
    min_scale_factor: Minimum scale factor value.
    max_scale_factor: Maximum scale factor value.
    scale_factor_step_size: The step size from min scale factor to max scale
      factor. The input is randomly scaled based on the value of
      (min_scale_factor, max_scale_factor, scale_factor_step_size).
    num_readers: Number of readers for data provider.
    num_threads: Number of threads for batching data.
    dataset_split: Dataset split.
    is_training: Is training or not.
    model_variant: Model variant (string) for choosing how to mean-subtract
      the images. See feature_extractor.network_map for supported model
      variants.

  Returns:
    A dictionary of batched Tensors for semantic segmentation.

  Raises:
    ValueError: dataset_split is None, failed to find labels, or label shape
      is not valid.
  """
  if dataset_split is None:
    raise ValueError('Unknown dataset split.')
  if model_variant is None:
    tf.logging.warning('Please specify a model_variant. See '
                       'feature_extractor.network_map for supported model '
                       'variants.')

  data_provider = dataset_data_provider.DatasetDataProvider(
      dataset,
      num_readers=num_readers,
      num_epochs=None if is_training else 1,
      shuffle=is_training)
  image, label, image_name, height, width = _get_data(data_provider,
                                                      dataset_split)
  if label is not None:
    if label.shape.ndims == 2:
      label = tf.expand_dims(label, 2)
    elif label.shape.ndims == 3 and label.shape.dims[2] == 1:
      pass
    else:
      raise ValueError('Input label shape must be [height, width], or '
                       '[height, width, 1].')
    label.set_shape([None, None, 1])

  original_image, image, label = input_preprocess.preprocess_image_and_label(
      image,
      label,
      crop_height=crop_size[0],
      crop_width=crop_size[1],
      min_resize_value=min_resize_value,
      max_resize_value=max_resize_value,
      resize_factor=resize_factor,
      min_scale_factor=min_scale_factor,
      max_scale_factor=max_scale_factor,
      scale_factor_step_size=scale_factor_step_size,
      ignore_label=dataset.ignore_label,
      is_training=is_training,
      model_variant=model_variant)

  sample = {
      common.IMAGE: image,
      common.IMAGE_NAME: image_name,
      common.HEIGHT: height,
      common.WIDTH: width
  }
  if label is not None:
    sample[common.LABEL] = label

  if not is_training:
    # Original image is only used during visualization.
    sample[common.ORIGINAL_IMAGE] = original_image
    num_threads = 1

  return tf.train.batch(
      sample,
      batch_size=batch_size,
      num_threads=num_threads,
      capacity=32 * batch_size,
      allow_smaller_final_batch=not is_training,
      dynamic_pad=True)
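A usage sketch for this wrapper (the `dataset` object is a placeholder; note that `tf.train.batch` relies on input queues, which `tf.train.MonitoredTrainingSession` starts automatically):

```python
# Assumes `dataset` is a configured slim Dataset for a train split.
samples = get(
    dataset,
    crop_size=[513, 513],
    batch_size=4,
    dataset_split='train',
    is_training=True,
    model_variant='mobilenet_v2')

with tf.train.MonitoredTrainingSession() as sess:
  batch = sess.run(samples)  # dict keyed by common.IMAGE, common.LABEL, ...
  print(batch[common.IMAGE].shape)  # e.g. (4, 513, 513, 3)
```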
@@ -29,16 +29,20 @@ def save_annotation(label,
                     save_dir,
                     filename,
                     add_colormap=True,
+                    normalize_to_unit_values=False,
+                    scale_values=False,
                     colormap_type=get_dataset_colormap.get_pascal_name()):
   """Saves the given label to image on disk.
 
   Args:
     label: The numpy array to be saved. The data will be converted
       to uint8 and saved as png image.
-    save_dir: The directory to which the results will be saved.
-    filename: The image filename.
-    add_colormap: Add color map to the label or not.
-    colormap_type: Colormap type for visualization.
+    save_dir: String, the directory to which the results will be saved.
+    filename: String, the image filename.
+    add_colormap: Boolean, add color map to the label or not.
+    normalize_to_unit_values: Boolean, normalize the input values to [0, 1].
+    scale_values: Boolean, scale the input values to [0, 255] for visualization.
+    colormap_type: String, colormap type for visualization.
   """
   # Add colormap for visualizing the prediction.
   if add_colormap:
@@ -46,6 +50,15 @@ def save_annotation(label,
         label, colormap_type)
   else:
     colored_label = label
+    if normalize_to_unit_values:
+      min_value = np.amin(colored_label)
+      max_value = np.amax(colored_label)
+      range_value = max_value - min_value
+      if range_value != 0:
+        colored_label = (colored_label - min_value) / range_value
+
+    if scale_values:
+      colored_label = 255. * colored_label
 
   pil_image = img.fromarray(colored_label.astype(dtype=np.uint8))
   with tf.gfile.Open('%s/%s.png' % (save_dir, filename), mode='w') as f:
...
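As a usage sketch, the two new arguments are aimed at saving non-label float arrays such as raw score maps (the path and array below are illustrative):

```python
import numpy as np
from deeplab.utils import save_annotation

# A float-valued map rather than integer class labels.
score_map = np.random.uniform(size=(65, 65)).astype(np.float32)

# Skip the colormap, stretch values to [0, 1], then scale to [0, 255]
# so the saved PNG is actually visible.
save_annotation.save_annotation(
    score_map,
    save_dir='/tmp/vis',
    filename='score_map',
    add_colormap=False,
    normalize_to_unit_values=True,
    scale_values=True)
```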
@@ -19,7 +19,11 @@ import six
 import tensorflow as tf
 
 from deeplab.core import preprocess_utils
 
-slim = tf.contrib.slim
+
+def _div_maybe_zero(total_loss, num_present):
+  """Normalizes the total loss with the number of present pixels."""
+  return tf.to_float(num_present > 0) * tf.div(total_loss,
+                                               tf.maximum(1e-5, num_present))
@@ -28,6 +32,8 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
                                                   ignore_label,
                                                   loss_weight=1.0,
                                                   upsample_logits=True,
+                                                  hard_example_mining_step=0,
+                                                  top_k_percent_pixels=1.0,
                                                   scope=None):
   """Adds softmax cross entropy loss for logits of each scale.
@@ -39,6 +45,15 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
     ignore_label: Integer, label to ignore.
     loss_weight: Float, loss weight.
     upsample_logits: Boolean, upsample logits or not.
+    hard_example_mining_step: An integer, the training step in which the hard
+      example mining kicks off. Note that we gradually reduce the mining
+      percent to the top_k_percent_pixels. For example, if
+      hard_example_mining_step = 100K and top_k_percent_pixels = 0.25, then
+      mining percent will gradually reduce from 100% to 25% until 100K steps
+      after which we only mine the top 25% pixels.
+    top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its value
+      < 1.0, only compute the loss for the top k percent pixels (e.g., the top
+      20% pixels). This is useful for hard pixel mining.
     scope: String, the scope for the loss.
 
   Raises:
@@ -69,13 +84,48 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
     scaled_labels = tf.reshape(scaled_labels, shape=[-1])
     not_ignore_mask = tf.to_float(tf.not_equal(scaled_labels,
                                                ignore_label)) * loss_weight
-    one_hot_labels = slim.one_hot_encoding(
+    one_hot_labels = tf.one_hot(
         scaled_labels, num_classes, on_value=1.0, off_value=0.0)
-    tf.losses.softmax_cross_entropy(
-        one_hot_labels,
-        tf.reshape(logits, shape=[-1, num_classes]),
-        weights=not_ignore_mask,
-        scope=loss_scope)
+
+    if top_k_percent_pixels == 1.0:
+      # Compute the loss for all pixels.
+      tf.losses.softmax_cross_entropy(
+          one_hot_labels,
+          tf.reshape(logits, shape=[-1, num_classes]),
+          weights=not_ignore_mask,
+          scope=loss_scope)
+    else:
+      logits = tf.reshape(logits, shape=[-1, num_classes])
+      weights = not_ignore_mask
+      with tf.name_scope(loss_scope, 'softmax_hard_example_mining',
+                         [logits, one_hot_labels, weights]):
+        one_hot_labels = tf.stop_gradient(
+            one_hot_labels, name='labels_stop_gradient')
+        pixel_losses = tf.nn.softmax_cross_entropy_with_logits_v2(
+            labels=one_hot_labels,
+            logits=logits,
+            name='pixel_losses')
+        weighted_pixel_losses = tf.multiply(pixel_losses, weights)
+        num_pixels = tf.to_float(tf.shape(logits)[0])
+        # Compute the top_k_percent pixels based on current training step.
+        if hard_example_mining_step == 0:
+          # Directly focus on the top_k pixels.
+          top_k_pixels = tf.to_int32(top_k_percent_pixels * num_pixels)
+        else:
+          # Gradually reduce the mining percent to top_k_percent_pixels.
+          global_step = tf.to_float(tf.train.get_or_create_global_step())
+          ratio = tf.minimum(1.0, global_step / hard_example_mining_step)
+          top_k_pixels = tf.to_int32(
+              (ratio * top_k_percent_pixels + (1.0 - ratio)) * num_pixels)
+        top_k_losses, _ = tf.nn.top_k(weighted_pixel_losses,
+                                      k=top_k_pixels,
+                                      sorted=True,
+                                      name='top_k_percent_pixels')
+        total_loss = tf.reduce_sum(top_k_losses)
+        num_present = tf.reduce_sum(
+            tf.to_float(tf.not_equal(top_k_losses, 0.0)))
+        loss = _div_maybe_zero(total_loss, num_present)
+        tf.losses.add_loss(loss)
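To make the mining schedule concrete, here is the same ratio arithmetic as plain Python with illustrative numbers (this mirrors the TensorFlow ops above; it is not code from the commit):

```python
def mining_percent(step, hard_example_mining_step, top_k_percent_pixels):
  """Fraction of pixels whose losses are kept at a given global step."""
  if hard_example_mining_step == 0:
    return top_k_percent_pixels
  ratio = min(1.0, step / float(hard_example_mining_step))
  return ratio * top_k_percent_pixels + (1.0 - ratio)

# With hard_example_mining_step=100000 and top_k_percent_pixels=0.25:
print(mining_percent(0, 100000, 0.25))       # 1.0, all pixels contribute
print(mining_percent(50000, 100000, 0.25))   # 0.625, halfway through burn-in
print(mining_percent(100000, 100000, 0.25))  # 0.25, only the top 25% remain
```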
@@ -110,13 +160,22 @@ def get_model_init_fn(train_logdir,
   if not initialize_last_layer:
     exclude_list.extend(last_layers)
 
-  variables_to_restore = slim.get_variables_to_restore(exclude=exclude_list)
+  variables_to_restore = tf.contrib.framework.get_variables_to_restore(
+      exclude=exclude_list)
 
   if variables_to_restore:
-    return slim.assign_from_checkpoint_fn(
+    init_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
         tf_initial_checkpoint,
         variables_to_restore,
         ignore_missing_vars=ignore_missing_vars)
+    global_step = tf.train.get_or_create_global_step()
+
+    def restore_fn(unused_scaffold, sess):
+      sess.run(init_op, init_feed_dict)
+      sess.run([global_step])
+
+    return restore_fn
+
   return None
@@ -138,7 +197,7 @@ def get_model_gradient_multipliers(last_layers, last_layer_gradient_multiplier):
   """
   gradient_multipliers = {}
 
-  for var in slim.get_model_variables():
+  for var in tf.model_variables():
     # Double the learning rate for biases.
     if 'biases' in var.op.name:
       gradient_multipliers[var.op.name] = 2.
@@ -155,10 +214,15 @@ def get_model_gradient_multipliers(last_layers, last_layer_gradient_multiplier):
   return gradient_multipliers
 
 
-def get_model_learning_rate(
-    learning_policy, base_learning_rate, learning_rate_decay_step,
-    learning_rate_decay_factor, training_number_of_steps, learning_power,
-    slow_start_step, slow_start_learning_rate):
+def get_model_learning_rate(learning_policy,
+                            base_learning_rate,
+                            learning_rate_decay_step,
+                            learning_rate_decay_factor,
+                            training_number_of_steps,
+                            learning_power,
+                            slow_start_step,
+                            slow_start_learning_rate,
+                            slow_start_burnin_type='none'):
   """Gets model's learning rate.
 
   Computes the model's learning rate for different learning policy.
@@ -181,31 +245,51 @@ def get_model_learning_rate(
     slow_start_step: Training model with small learning rate for the first
       few steps.
     slow_start_learning_rate: The learning rate employed during slow start.
+    slow_start_burnin_type: The burnin type for the slow start stage. Can be
+      `none` which means no burnin or `linear` which means the learning rate
+      increases linearly from slow_start_learning_rate and reaches
+      base_learning_rate after slow_start_steps.
 
   Returns:
     Learning rate for the specified learning policy.
 
   Raises:
-    ValueError: If learning policy is not recognized.
+    ValueError: If learning policy or slow start burnin type is not recognized.
   """
   global_step = tf.train.get_or_create_global_step()
+  adjusted_global_step = global_step
+
+  if slow_start_burnin_type != 'none':
+    adjusted_global_step -= slow_start_step
+
   if learning_policy == 'step':
     learning_rate = tf.train.exponential_decay(
         base_learning_rate,
-        global_step,
+        adjusted_global_step,
         learning_rate_decay_step,
         learning_rate_decay_factor,
         staircase=True)
   elif learning_policy == 'poly':
     learning_rate = tf.train.polynomial_decay(
         base_learning_rate,
-        global_step,
+        adjusted_global_step,
         training_number_of_steps,
         end_learning_rate=0,
         power=learning_power)
   else:
     raise ValueError('Unknown learning policy.')
 
+  adjusted_slow_start_learning_rate = slow_start_learning_rate
+  if slow_start_burnin_type == 'linear':
+    # Do linear burnin. Increase linearly from slow_start_learning_rate and
+    # reach base_learning_rate after (global_step >= slow_start_steps).
+    adjusted_slow_start_learning_rate = (
+        slow_start_learning_rate +
+        (base_learning_rate - slow_start_learning_rate) *
+        tf.to_float(global_step) / slow_start_step)
+  elif slow_start_burnin_type != 'none':
+    raise ValueError('Unknown burnin type.')
+
   # Employ small learning rate at the first few steps for warm start.
-  return tf.where(global_step < slow_start_step, slow_start_learning_rate,
-                  learning_rate)
+  return tf.where(global_step < slow_start_step,
+                  adjusted_slow_start_learning_rate, learning_rate)
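Similarly, a plain-Python sketch of the linear burnin arithmetic above, with made-up hyperparameters (the real code computes this with TensorFlow ops and selects between it and the decay policy via `tf.where`):

```python
def burnin_learning_rate(step, slow_start_step, slow_start_lr, base_lr):
  """Learning rate used while step < slow_start_step under linear burnin."""
  return (slow_start_lr +
          (base_lr - slow_start_lr) * float(step) / slow_start_step)

# With slow_start_step=5000, slow_start_learning_rate=1e-4, base=7e-3:
print(burnin_learning_rate(0, 5000, 1e-4, 7e-3))     # 0.0001
print(burnin_learning_rate(2500, 5000, 1e-4, 7e-3))  # 0.00355
print(burnin_learning_rate(5000, 5000, 1e-4, 7e-3))  # 0.007 (= base rate)
```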
@@ -17,19 +17,15 @@
 
 See model.py for more details and usage.
 """
-import math
 import os.path
 import time
 import numpy as np
 import tensorflow as tf
 from deeplab import common
 from deeplab import model
-from deeplab.datasets import segmentation_dataset
-from deeplab.utils import input_generator
+from deeplab.datasets import data_generator
 from deeplab.utils import save_annotation
 
-slim = tf.contrib.slim
-
 flags = tf.app.flags
 
 FLAGS = flags.FLAGS
@@ -186,11 +182,24 @@ def _process_batch(sess, original_images, semantic_predictions, image_names,
 def main(unused_argv):
   tf.logging.set_verbosity(tf.logging.INFO)
 
   # Get dataset-dependent information.
-  dataset = segmentation_dataset.get_dataset(
-      FLAGS.dataset, FLAGS.vis_split, dataset_dir=FLAGS.dataset_dir)
+  dataset = data_generator.Dataset(
+      dataset_name=FLAGS.dataset,
+      split_name=FLAGS.vis_split,
+      dataset_dir=FLAGS.dataset_dir,
+      batch_size=FLAGS.vis_batch_size,
+      crop_size=FLAGS.vis_crop_size,
+      min_resize_value=FLAGS.min_resize_value,
+      max_resize_value=FLAGS.max_resize_value,
+      resize_factor=FLAGS.resize_factor,
+      model_variant=FLAGS.model_variant,
+      is_training=False,
+      should_shuffle=False,
+      should_repeat=False)
 
   train_id_to_eval_id = None
-  if dataset.name == segmentation_dataset.get_cityscapes_dataset_name():
+  if dataset.dataset_name == data_generator.get_cityscapes_dataset_name():
     tf.logging.info('Cityscapes requires converting train_id to eval_id.')
     train_id_to_eval_id = _CITYSCAPES_TRAIN_ID_TO_EVAL_ID
@@ -204,20 +213,11 @@ def main(unused_argv):
 
   tf.logging.info('Visualizing on %s set', FLAGS.vis_split)
 
-  g = tf.Graph()
-  with g.as_default():
-    samples = input_generator.get(dataset,
-                                  FLAGS.vis_crop_size,
-                                  FLAGS.vis_batch_size,
-                                  min_resize_value=FLAGS.min_resize_value,
-                                  max_resize_value=FLAGS.max_resize_value,
-                                  resize_factor=FLAGS.resize_factor,
-                                  dataset_split=FLAGS.vis_split,
-                                  is_training=False,
-                                  model_variant=FLAGS.model_variant)
+  with tf.Graph().as_default():
+    samples = dataset.get_one_shot_iterator().get_next()
+
     model_options = common.ModelOptions(
-        outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_classes},
+        outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_of_classes},
         crop_size=FLAGS.vis_crop_size,
         atrous_rates=FLAGS.atrous_rates,
         output_stride=FLAGS.output_stride)
@@ -244,7 +244,7 @@ def main(unused_argv):
     # Reverse the resizing and padding operations performed in preprocessing.
     # First, we slice the valid regions (i.e., remove padded region) and then
-    # we reisze the predictions back.
+    # we resize the predictions back.
     original_image = tf.squeeze(samples[common.ORIGINAL_IMAGE])
     original_image_shape = tf.shape(original_image)
     predictions = tf.slice(
@@ -259,40 +259,35 @@ def main(unused_argv):
             method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
             align_corners=True), 3)
 
-    tf.train.get_or_create_global_step()
-    saver = tf.train.Saver(slim.get_variables_to_restore())
-    sv = tf.train.Supervisor(graph=g,
-                             logdir=FLAGS.vis_logdir,
-                             init_op=tf.global_variables_initializer(),
-                             summary_op=None,
-                             summary_writer=None,
-                             global_step=None,
-                             saver=saver)
-    num_batches = int(math.ceil(
-        dataset.num_samples / float(FLAGS.vis_batch_size)))
-    last_checkpoint = None
+    num_iteration = 0
+    max_num_iteration = FLAGS.max_number_of_iterations
 
-    # Loop to visualize the results when new checkpoint is created.
-    num_iters = 0
-    while (FLAGS.max_number_of_iterations <= 0 or
-           num_iters < FLAGS.max_number_of_iterations):
-      num_iters += 1
-      last_checkpoint = slim.evaluation.wait_for_new_checkpoint(
-          FLAGS.checkpoint_dir, last_checkpoint)
-      start = time.time()
+    checkpoints_iterator = tf.contrib.training.checkpoints_iterator(
+        FLAGS.checkpoint_dir, min_interval_secs=FLAGS.eval_interval_secs)
+    for checkpoint_path in checkpoints_iterator:
+      if max_num_iteration > 0 and num_iteration > max_num_iteration:
+        break
+      num_iteration += 1
       tf.logging.info(
           'Starting visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                        time.gmtime()))
-      tf.logging.info('Visualizing with model %s', last_checkpoint)
+      tf.logging.info('Visualizing with model %s', checkpoint_path)
 
-      with sv.managed_session(FLAGS.master,
-                              start_standard_services=False) as sess:
-        sv.start_queue_runners(sess)
-        sv.saver.restore(sess, last_checkpoint)
+      tf.train.get_or_create_global_step()
+      scaffold = tf.train.Scaffold(init_op=tf.global_variables_initializer())
+      session_creator = tf.train.ChiefSessionCreator(
+          scaffold=scaffold,
+          master=FLAGS.master,
+          checkpoint_filename_with_path=checkpoint_path)
+      with tf.train.MonitoredSession(
+          session_creator=session_creator, hooks=None) as sess:
+        batch = 0
         image_id_offset = 0
-        for batch in range(num_batches):
-          tf.logging.info('Visualizing batch %d / %d', batch + 1, num_batches)
+
+        while not sess.should_stop():
+          tf.logging.info('Visualizing batch %d', batch + 1)
           _process_batch(sess=sess,
                          original_images=samples[common.ORIGINAL_IMAGE],
                          semantic_predictions=predictions,
@@ -304,14 +299,11 @@ def main(unused_argv):
                          raw_save_dir=raw_save_dir,
                          train_id_to_eval_id=train_id_to_eval_id)
           image_id_offset += FLAGS.vis_batch_size
+          batch += 1
 
       tf.logging.info(
          'Finished visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                       time.gmtime()))
-      time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
-      if time_to_next_eval > 0:
-        time.sleep(time_to_next_eval)
 
 
 if __name__ == '__main__':
   flags.mark_flag_as_required('checkpoint_dir')
...