Commit 2c962110 authored by huihui-personal, committed by aquariusjay

Open source deeplab changes. See the change log in README.md for detailed info. (#6231)

* Refactor deeplab to use MonitoredTrainingSession

PiperOrigin-RevId: 234237190

* Update export_model.py

* Update nas_cell.py

* Update nas_network.py

* Update train.py

* Update deeplab_demo.ipynb

* Update nas_cell.py
parent a432998c
@@ -66,6 +66,11 @@ A: Our model uses whole-image inference, meaning that we need to set `eval_crop_
image dimension in the dataset. For example, we have `eval_crop_size` = 513x513 for PASCAL dataset whose largest image dimension is 512. Similarly, we set `eval_crop_size` = 1025x2049 for Cityscapes images whose
image dimensions all equal 1024x2048.
___
Q9: Why is multi-GPU training slow?
A: Please try to use more threads to pre-process the inputs. For example, change [num_readers = 4](https://github.com/tensorflow/models/blob/master/research/deeplab/utils/input_generator.py#L71) and [num_threads = 4](https://github.com/tensorflow/models/blob/master/research/deeplab/utils/input_generator.py#L72).
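For illustration only (this snippet is not part of the repository docs), the same parallelism can also be requested through the `num_readers` and `num_threads` arguments of `input_generator.get()`; the dataset object, crop size and batch size below are placeholders:

```python
# Hedged sketch: raise input-pipeline parallelism via num_readers/num_threads.
# `dataset` stands for a slim Dataset instance built for the training split.
from deeplab.utils import input_generator

samples = input_generator.get(
    dataset,
    crop_size=[513, 513],
    batch_size=8,
    num_readers=4,   # more parallel TFRecord readers
    num_threads=4,   # more preprocessing/batching threads
    dataset_split='train',
    is_training=True,
    model_variant='xception_65')
```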
___
## References
...
@@ -32,13 +32,13 @@ sudo pip install matplotlib
## Add Libraries to PYTHONPATH
When running locally, the tensorflow/models/research/ directory should be
appended to PYTHONPATH. This can be done by running the following from
tensorflow/models/research/:
```bash
# From tensorflow/models/research/
export PYTHONPATH=$PYTHONPATH:`pwd`
```
Note: This command needs to run from every new terminal you start. If you wish
...
@@ -88,14 +88,19 @@ We provide some checkpoints that have been pretrained on ADE20K training set.
Note that the model has only been pretrained on ImageNet, following the
dataset rule.

Checkpoint name | Network backbone | Pretrained dataset | ASPP | Decoder | Input size
------------------------------------- | :--------------: | :-------------------------------------: | :----------------------------------------------: | :-----: | :-----:
mobilenetv2_ade20k_train | MobileNet-v2 | ImageNet <br> ADE20K training set | N/A | OS = 4 | 257x257
xception65_ade20k_train | Xception_65 | ImageNet <br> ADE20K training set | [6, 12, 18] for OS=16 <br> [12, 24, 36] for OS=8 | OS = 4 | 513x513
The input dimensions of ADE20K images vary widely. We resize inputs so that the longest side is 257 for MobileNet-v2 (faster inference) and 513 for Xception_65 (better performance). Note that we also include the decoder module in the MobileNet-v2 checkpoint.
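As a rough sketch of this resize rule (the repository itself drives resizing through the `min_resize_value`/`max_resize_value` preprocessing options; the helper below is only illustrative):

```python
import tensorflow as tf

def resize_longest_side(image, longest_side=513):
  """Scales a [height, width, channels] image so that its longest side equals
  `longest_side`, e.g. 257 for MobileNet-v2 or 513 for Xception_65."""
  shape = tf.to_float(tf.shape(image))
  scale = tf.to_float(longest_side) / tf.maximum(shape[0], shape[1])
  new_size = tf.to_int32([shape[0] * scale, shape[1] * scale])
  resized = tf.image.resize_bilinear(
      tf.expand_dims(image, 0), new_size, align_corners=True)
  return tf.squeeze(resized, 0)
```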
Checkpoint name | Eval OS | Eval scales | Left-right Flip | mIOU | Pixel-wise Accuracy | File Size
------------------------------------- | :-------: | :-------------------------: | :-------------: | :-------------------: | :-------------------: | :-------:
[mobilenetv2_ade20k_train](http://download.tensorflow.org/models/deeplabv3_mnv2_ade20k_train_2018_12_03.tar.gz) | 16 | [1.0] | No | 32.04% (val) | 75.41% (val) | 24.8MB
[xception65_ade20k_train](http://download.tensorflow.org/models/deeplabv3_xception_ade20k_train_2018_05_29.tar.gz) | 8 | [0.5:0.25:1.75] | Yes | 45.65% (val) | 82.52% (val) | 439MB

## Checkpoints pretrained on ImageNet

Un-tar'ed directory includes:
...
@@ -82,7 +82,7 @@ def preprocess_image_and_label(image,
label = tf.cast(label, tf.int32)
# Resize image and label to the desired range.
if min_resize_value or max_resize_value:
[processed_image, label] = (
preprocess_utils.resize_to_range(
image=processed_image,
...
@@ -56,7 +56,6 @@ from deeplab.core import dense_prediction_cell
from deeplab.core import feature_extractor
from deeplab.core import utils
slim = tf.contrib.slim
LOGITS_SCOPE_NAME = 'logits'
@@ -67,9 +66,11 @@ CONCAT_PROJECTION_SCOPE = 'concat_projection'
DECODER_SCOPE = 'decoder'
META_ARCHITECTURE_SCOPE = 'meta_architecture'
_resize_bilinear = utils.resize_bilinear
scale_dimension = utils.scale_dimension
split_separable_conv2d = utils.split_separable_conv2d
def get_extra_layer_scopes(last_layers_contain_logits_only=False):
"""Gets the scopes for extra layers.
@@ -135,20 +136,20 @@ def predict_labels_multi_scale(images,
for output in sorted(outputs_to_scales_to_logits):
scales_to_logits = outputs_to_scales_to_logits[output]
logits = _resize_bilinear(
scales_to_logits[MERGED_LOGITS_SCOPE],
tf.shape(images)[1:3],
scales_to_logits[MERGED_LOGITS_SCOPE].dtype)
outputs_to_predictions[output].append(
tf.expand_dims(tf.nn.softmax(logits), 4))
if add_flipped_images:
scales_to_logits_reversed = (
outputs_to_scales_to_logits_reversed[output])
logits_reversed = _resize_bilinear(
tf.reverse_v2(scales_to_logits_reversed[MERGED_LOGITS_SCOPE], [2]),
tf.shape(images)[1:3],
scales_to_logits_reversed[MERGED_LOGITS_SCOPE].dtype)
outputs_to_predictions[output].append(
tf.expand_dims(tf.nn.softmax(logits_reversed), 4))
@@ -184,37 +185,35 @@ def predict_labels(images, model_options, image_pyramid=None):
predictions = {}
for output in sorted(outputs_to_scales_to_logits):
scales_to_logits = outputs_to_scales_to_logits[output]
logits = scales_to_logits[MERGED_LOGITS_SCOPE]
# There are two ways to obtain the final prediction results: (1) bilinear
# upsampling the logits followed by argmax, or (2) argmax followed by
# nearest neighbor upsampling. The second option may introduce the "blocking
# effect" but is computationally efficient.
if model_options.prediction_with_upsampled_logits:
logits = _resize_bilinear(logits,
tf.shape(images)[1:3],
scales_to_logits[MERGED_LOGITS_SCOPE].dtype)
predictions[output] = tf.argmax(logits, 3)
else:
argmax_results = tf.argmax(logits, 3)
argmax_results = tf.image.resize_nearest_neighbor(
tf.expand_dims(argmax_results, 3),
tf.shape(images)[1:3],
align_corners=True,
name='resize_prediction')
predictions[output] = tf.squeeze(argmax_results, 3)
return predictions
-def _resize_bilinear(images, size, output_dtype=tf.float32):
-"""Returns resized images as output_type.
-Args:
-images: A tensor of size [batch, height_in, width_in, channels].
-size: A 1-D int32 Tensor of 2 elements: new_height, new_width. The new size
-for the images.
-output_dtype: The destination type.
-Returns:
-A tensor of size [batch, height_out, width_out, channels] as a dtype of
-output_dtype.
-"""
-images = tf.image.resize_bilinear(images, size, align_corners=True)
-return tf.cast(images, dtype=output_dtype)
def multi_scale_logits(images,
model_options,
image_pyramid,
weight_decay=0.0001,
is_training=False,
fine_tune_batch_norm=False,
nas_training_hyper_parameters=None):
"""Gets the logits for multi-scale inputs.
The returned logits are all downsampled (due to max-pooling layers)
@@ -227,6 +226,12 @@ def multi_scale_logits(images,
weight_decay: The weight decay for model variables.
is_training: Is training or not.
fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
nas_training_hyper_parameters: A dictionary storing hyper-parameters for
training nas models. Its keys are:
- `drop_path_keep_prob`: Probability to keep each path in the cell when
training.
- `total_training_steps`: Total training steps to help drop path
probability calculation.
Returns:
outputs_to_scales_to_logits: A map of maps from output_type (e.g.,
@@ -250,10 +255,15 @@ def multi_scale_logits(images,
crop_width = (
model_options.crop_size[1]
if model_options.crop_size else tf.shape(images)[2])
if model_options.image_pooling_crop_size:
image_pooling_crop_height = model_options.image_pooling_crop_size[0]
image_pooling_crop_width = model_options.image_pooling_crop_size[1]
# Compute the height, width for the output logits.
if model_options.decoder_output_stride:
logits_output_stride = min(model_options.decoder_output_stride)
else:
logits_output_stride = model_options.output_stride
logits_height = scale_dimension(
crop_height,
@@ -268,33 +278,45 @@ def multi_scale_logits(images,
for k in model_options.outputs_to_num_classes
}
num_channels = images.get_shape().as_list()[-1]
for image_scale in image_pyramid:
if image_scale != 1.0:
scaled_height = scale_dimension(crop_height, image_scale)
scaled_width = scale_dimension(crop_width, image_scale)
scaled_crop_size = [scaled_height, scaled_width]
scaled_images = _resize_bilinear(images, scaled_crop_size, images.dtype)
if model_options.crop_size:
scaled_images.set_shape(
[None, scaled_height, scaled_width, num_channels])
# Adjust image_pooling_crop_size accordingly.
scaled_image_pooling_crop_size = None
if model_options.image_pooling_crop_size:
scaled_image_pooling_crop_size = [
scale_dimension(image_pooling_crop_height, image_scale),
scale_dimension(image_pooling_crop_width, image_scale)]
else:
scaled_crop_size = model_options.crop_size
scaled_images = images
scaled_image_pooling_crop_size = model_options.image_pooling_crop_size
updated_options = model_options._replace(
crop_size=scaled_crop_size,
image_pooling_crop_size=scaled_image_pooling_crop_size)
outputs_to_logits = _get_logits(
scaled_images,
updated_options,
weight_decay=weight_decay,
reuse=tf.AUTO_REUSE,
is_training=is_training,
fine_tune_batch_norm=fine_tune_batch_norm,
nas_training_hyper_parameters=nas_training_hyper_parameters)
# Resize the logits to have the same dimension before merging.
for output in sorted(outputs_to_logits):
outputs_to_logits[output] = _resize_bilinear(
outputs_to_logits[output], [logits_height, logits_width],
outputs_to_logits[output].dtype)
# Return when only one input scale.
if len(image_pyramid) == 1:
@@ -330,7 +352,8 @@ def extract_features(images,
weight_decay=0.0001,
reuse=None,
is_training=False,
fine_tune_batch_norm=False,
nas_training_hyper_parameters=None):
"""Extracts features by the particular model_variant.
Args:
@@ -340,6 +363,12 @@ def extract_features(images,
reuse: Reuse the model variables or not.
is_training: Is training or not.
fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
nas_training_hyper_parameters: A dictionary storing hyper-parameters for
training nas models. Its keys are:
- `drop_path_keep_prob`: Probability to keep each path in the cell when
training.
- `total_training_steps`: Total training steps to help drop path
probability calculation.
Returns:
concat_logits: A tensor of size [batch, feature_height, feature_width,
@@ -354,10 +383,16 @@ def extract_features(images,
multi_grid=model_options.multi_grid,
model_variant=model_options.model_variant,
depth_multiplier=model_options.depth_multiplier,
divisible_by=model_options.divisible_by,
weight_decay=weight_decay,
reuse=reuse,
is_training=is_training,
preprocessed_images_dtype=model_options.preprocessed_images_dtype,
fine_tune_batch_norm=fine_tune_batch_norm,
nas_stem_output_num_conv_filters=(
model_options.nas_stem_output_num_conv_filters),
nas_training_hyper_parameters=nas_training_hyper_parameters,
use_bounded_activation=model_options.use_bounded_activation)
if not model_options.aspp_with_batch_norm:
return features, end_points
@@ -380,10 +415,10 @@ def extract_features(images,
fine_tune_batch_norm=fine_tune_batch_norm)
return concat_logits, end_points
else:
# The following codes employ the DeepLabv3 ASPP module. Note that we
# could express the ASPP module as one particular dense prediction
# cell architecture. We do not do so but leave the following codes
# for backward compatibility.
batch_norm_params = {
'is_training': is_training and fine_tune_batch_norm,
'decay': 0.9997,
@@ -391,10 +426,13 @@ def extract_features(images,
'scale': True,
}
activation_fn = (
tf.nn.relu6 if model_options.use_bounded_activation else tf.nn.relu)
with slim.arg_scope(
[slim.conv2d, slim.separable_conv2d],
weights_regularizer=slim.l2_regularizer(weight_decay),
activation_fn=activation_fn,
normalizer_fn=slim.batch_norm,
padding='SAME',
stride=1,
@@ -416,7 +454,8 @@ def extract_features(images,
image_pooling_crop_size[1],
1. / model_options.output_stride)
image_feature = slim.avg_pool2d(
features, [pool_height, pool_width],
model_options.image_pooling_stride, padding='VALID')
resize_height = scale_dimension(
model_options.crop_size[0],
1. / model_options.output_stride)
@@ -483,7 +522,8 @@ def _get_logits(images,
weight_decay=0.0001,
reuse=None,
is_training=False,
fine_tune_batch_norm=False,
nas_training_hyper_parameters=None):
"""Gets the logits by atrous/image spatial pyramid pooling.
Args:
@@ -493,6 +533,12 @@ def _get_logits(images,
reuse: Reuse the model variables or not.
is_training: Is training or not.
fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
nas_training_hyper_parameters: A dictionary storing hyper-parameters for
training nas models. Its keys are:
- `drop_path_keep_prob`: Probability to keep each path in the cell when
training.
- `total_training_steps`: Total training steps to help drop path
probability calculation.
Returns:
outputs_to_logits: A map from output_type to logits.
@@ -503,29 +549,22 @@ def _get_logits(images,
weight_decay=weight_decay,
reuse=reuse,
is_training=is_training,
fine_tune_batch_norm=fine_tune_batch_norm,
nas_training_hyper_parameters=nas_training_hyper_parameters)
if model_options.decoder_output_stride is not None:
-if model_options.crop_size is None:
-height = tf.shape(images)[1]
-width = tf.shape(images)[2]
-else:
-height, width = model_options.crop_size
-decoder_height = scale_dimension(height,
-1.0 / model_options.decoder_output_stride)
-decoder_width = scale_dimension(width,
-1.0 / model_options.decoder_output_stride)
features = refine_by_decoder(
features,
end_points,
crop_size=model_options.crop_size,
decoder_output_stride=model_options.decoder_output_stride,
decoder_use_separable_conv=model_options.decoder_use_separable_conv,
model_variant=model_options.model_variant,
weight_decay=weight_decay,
reuse=reuse,
is_training=is_training,
fine_tune_batch_norm=fine_tune_batch_norm,
use_bounded_activation=model_options.use_bounded_activation)
outputs_to_logits = {}
for output in sorted(model_options.outputs_to_num_classes):
@@ -544,14 +583,15 @@ def _get_logits(images,
def refine_by_decoder(features,
end_points,
crop_size=None,
decoder_output_stride=None,
decoder_use_separable_conv=False,
model_variant=None,
weight_decay=0.0001,
reuse=None,
is_training=False,
fine_tune_batch_norm=False,
use_bounded_activation=False):
"""Adds the decoder to obtain sharper segmentation results.
Args:
@@ -559,19 +599,28 @@ def refine_by_decoder(features,
features_channels].
end_points: A dictionary from components of the network to the corresponding
activation.
crop_size: A tuple [crop_height, crop_width] specifying whole patch crop
size.
decoder_output_stride: A list of integers specifying the output stride of
low-level features used in the decoder module.
decoder_use_separable_conv: Employ separable convolution for decoder or not.
model_variant: Model variant for feature extraction.
weight_decay: The weight decay for model variables.
reuse: Reuse the model variables or not.
is_training: Is training or not.
fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
use_bounded_activation: Whether or not to use bounded activations. Bounded
activations better lend themselves to quantized inference.
Returns:
Decoder output with size [batch, decoder_height, decoder_width,
decoder_channels].
Raises:
ValueError: If crop_size is None.
""" """
if crop_size is None:
raise ValueError('crop_size must be provided when using decoder.')
batch_norm_params = { batch_norm_params = {
'is_training': is_training and fine_tune_batch_norm, 'is_training': is_training and fine_tune_batch_norm,
'decay': 0.9997, 'decay': 0.9997,
...@@ -582,25 +631,28 @@ def refine_by_decoder(features, ...@@ -582,25 +631,28 @@ def refine_by_decoder(features,
with slim.arg_scope(
[slim.conv2d, slim.separable_conv2d],
weights_regularizer=slim.l2_regularizer(weight_decay),
activation_fn=tf.nn.relu6 if use_bounded_activation else tf.nn.relu,
normalizer_fn=slim.batch_norm,
padding='SAME',
stride=1,
reuse=reuse):
with slim.arg_scope([slim.batch_norm], **batch_norm_params):
with tf.variable_scope(DECODER_SCOPE, DECODER_SCOPE, [features]):
-feature_list = feature_extractor.networks_to_feature_maps[
-model_variant][feature_extractor.DECODER_END_POINTS]
-if feature_list is None:
-tf.logging.info('Not found any decoder end points.')
-return features
-else:
decoder_features = features
decoder_stage = 0
scope_suffix = ''
for output_stride in decoder_output_stride:
feature_list = feature_extractor.networks_to_feature_maps[
model_variant][
feature_extractor.DECODER_END_POINTS][output_stride]
# If only one decoder stage, we do not change the scope name in
# order for backward compatibility.
if decoder_stage:
scope_suffix = '_{}'.format(decoder_stage)
for i, name in enumerate(feature_list):
decoder_features_list = [decoder_features]
# MobileNet and NAS variants use different naming convention.
if 'mobilenet' in model_variant or model_variant.startswith('nas'):
feature_name = name
else:
feature_name = '{}/{}'.format(
@@ -610,11 +662,14 @@ def refine_by_decoder(features,
end_points[feature_name],
48,
1,
scope='feature_projection' + str(i) + scope_suffix))
# Determine the output size.
decoder_height = scale_dimension(crop_size[0], 1.0 / output_stride)
decoder_width = scale_dimension(crop_size[1], 1.0 / output_stride)
# Resize to decoder_height/decoder_width.
for j, feature in enumerate(decoder_features_list):
decoder_features_list[j] = _resize_bilinear(
feature, [decoder_height, decoder_width], feature.dtype)
h = (None if isinstance(decoder_height, tf.Tensor)
else decoder_height)
w = (None if isinstance(decoder_width, tf.Tensor)
@@ -627,13 +682,13 @@ def refine_by_decoder(features,
filters=decoder_depth,
rate=1,
weight_decay=weight_decay,
scope='decoder_conv0' + scope_suffix)
decoder_features = split_separable_conv2d(
decoder_features,
filters=decoder_depth,
rate=1,
weight_decay=weight_decay,
scope='decoder_conv1' + scope_suffix)
else:
num_convs = 2
decoder_features = slim.repeat(
@@ -642,7 +697,8 @@ def refine_by_decoder(features,
slim.conv2d,
decoder_depth,
3,
scope='decoder_conv' + str(i) + scope_suffix)
decoder_stage += 1
return decoder_features
...
@@ -87,6 +87,7 @@ class DeeplabModelTest(tf.test.TestCase):
add_image_level_feature=True,
aspp_with_batch_norm=True,
logits_kernel_size=1,
decoder_output_stride=[4],
model_variant='mobilenet_v2') # Employ MobileNetv2 for fast test.
g = tf.Graph()
@@ -137,8 +138,8 @@ class DeeplabModelTest(tf.test.TestCase):
image_pyramid=[1.0])
for output in outputs_to_num_classes:
scales_to_model_results = outputs_to_scales_to_model_results[output]
self.assertListEqual(
list(scales_to_model_results), expected_endpoints)
self.assertEqual(len(scales_to_model_results), 1)
...
This directory contains testing data.
# pascal_voc_seg
This folder contains data specific to the pascal_voc_seg dataset. val-00000-of-00001.tfrecord contains
three randomly generated images with the format defined in
tensorflow/models/research/deeplab/datasets/build_voc2012_data.py.
@@ -37,7 +37,7 @@ _PASCAL = 'pascal'
# Max number of entries in the colormap for each dataset.
_DATASET_MAX_ENTRIES = {
_ADE20K: 151,
_CITYSCAPES: 256,
_MAPILLARY_VISTAS: 66,
_PASCAL: 256,
}
@@ -210,27 +210,27 @@ def create_cityscapes_label_colormap():
Returns:
A colormap for visualizing segmentation results.
"""
colormap = np.zeros((256, 3), dtype=np.uint8)
colormap[0] = [128, 64, 128]
colormap[1] = [244, 35, 232]
colormap[2] = [70, 70, 70]
colormap[3] = [102, 102, 156]
colormap[4] = [190, 153, 153]
colormap[5] = [153, 153, 153]
colormap[6] = [250, 170, 30]
colormap[7] = [220, 220, 0]
colormap[8] = [107, 142, 35]
colormap[9] = [152, 251, 152]
colormap[10] = [70, 130, 180]
colormap[11] = [220, 20, 60]
colormap[12] = [255, 0, 0]
colormap[13] = [0, 0, 142]
colormap[14] = [0, 0, 70]
colormap[15] = [0, 60, 100]
colormap[16] = [0, 80, 100]
colormap[17] = [0, 0, 230]
colormap[18] = [119, 11, 32]
return colormap
def create_mapillary_vistas_label_colormap():
@@ -396,10 +396,16 @@ def label_to_color_image(label, dataset=_PASCAL):
map maximum entry.
"""
if label.ndim != 2:
raise ValueError('Expect 2-D input label. Got {}'.format(label.shape))
if np.max(label) >= _DATASET_MAX_ENTRIES[dataset]:
raise ValueError(
'label value too large: {} >= {}.'.format(
np.max(label), _DATASET_MAX_ENTRIES[dataset]))
colormap = create_label_colormap(dataset)
return colormap[label]
def get_dataset_colormap_max_entries(dataset):
return _DATASET_MAX_ENTRIES[dataset]
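For reference, a minimal usage sketch (not from the repository) of the colormap utilities above; the `'cityscapes'` dataset name is assumed to match the `_CITYSCAPES` constant:

```python
# Hedged sketch: turn a 2-D array of Cityscapes train IDs into an RGB image.
import numpy as np
from deeplab.utils import get_dataset_colormap

label = np.zeros((1024, 2048), dtype=np.int32)  # all pixels labeled train id 0
color_image = get_dataset_colormap.label_to_color_image(
    label, dataset='cityscapes')  # uint8 array of shape [1024, 2048, 3]
```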
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for providing semantic segmentation data."""
import tensorflow as tf
from deeplab import common
from deeplab import input_preprocess
slim = tf.contrib.slim
dataset_data_provider = slim.dataset_data_provider
def _get_data(data_provider, dataset_split):
"""Gets data from data provider.
Args:
data_provider: An object of slim.data_provider.
dataset_split: Dataset split.
Returns:
image: Image Tensor.
label: Label Tensor storing segmentation annotations.
image_name: Image name.
height: Image height.
width: Image width.
Raises:
ValueError: Failed to find label.
"""
if common.LABELS_CLASS not in data_provider.list_items():
raise ValueError('Failed to find labels.')
image, height, width = data_provider.get(
[common.IMAGE, common.HEIGHT, common.WIDTH])
# Some datasets do not contain image_name.
if common.IMAGE_NAME in data_provider.list_items():
image_name, = data_provider.get([common.IMAGE_NAME])
else:
image_name = tf.constant('')
label = None
if dataset_split != common.TEST_SET:
label, = data_provider.get([common.LABELS_CLASS])
return image, label, image_name, height, width
def get(dataset,
crop_size,
batch_size,
min_resize_value=None,
max_resize_value=None,
resize_factor=None,
min_scale_factor=1.,
max_scale_factor=1.,
scale_factor_step_size=0,
num_readers=1,
num_threads=1,
dataset_split=None,
is_training=True,
model_variant=None):
"""Gets the dataset split for semantic segmentation.
This function gets the dataset split for semantic segmentation. In
particular, it is a wrapper of (1) dataset_data_provider which returns the raw
dataset split, (2) input_preprocess which preprocesses the raw data, and (3)
the TensorFlow operation of batching the preprocessed data. Then, the output
could be directly used by training, evaluation or visualization.
Args:
dataset: An instance of slim Dataset.
crop_size: Image crop size [height, width].
batch_size: Batch size.
min_resize_value: Desired size of the smaller image side.
max_resize_value: Maximum allowed size of the larger image side.
resize_factor: Resized dimensions are multiple of factor plus one.
min_scale_factor: Minimum scale factor value.
max_scale_factor: Maximum scale factor value.
scale_factor_step_size: The step size from min scale factor to max scale
factor. The input is randomly scaled based on the value of
(min_scale_factor, max_scale_factor, scale_factor_step_size).
num_readers: Number of readers for data provider.
num_threads: Number of threads for batching data.
dataset_split: Dataset split.
is_training: Is training or not.
model_variant: Model variant (string) for choosing how to mean-subtract the
images. See feature_extractor.network_map for supported model variants.
Returns:
A dictionary of batched Tensors for semantic segmentation.
Raises:
ValueError: dataset_split is None, failed to find labels, or label shape
is not valid.
"""
if dataset_split is None:
raise ValueError('Unknown dataset split.')
if model_variant is None:
tf.logging.warning('Please specify a model_variant. See '
'feature_extractor.network_map for supported model '
'variants.')
data_provider = dataset_data_provider.DatasetDataProvider(
dataset,
num_readers=num_readers,
num_epochs=None if is_training else 1,
shuffle=is_training)
image, label, image_name, height, width = _get_data(data_provider,
dataset_split)
if label is not None:
if label.shape.ndims == 2:
label = tf.expand_dims(label, 2)
elif label.shape.ndims == 3 and label.shape.dims[2] == 1:
pass
else:
raise ValueError('Input label shape must be [height, width], or '
'[height, width, 1].')
label.set_shape([None, None, 1])
original_image, image, label = input_preprocess.preprocess_image_and_label(
image,
label,
crop_height=crop_size[0],
crop_width=crop_size[1],
min_resize_value=min_resize_value,
max_resize_value=max_resize_value,
resize_factor=resize_factor,
min_scale_factor=min_scale_factor,
max_scale_factor=max_scale_factor,
scale_factor_step_size=scale_factor_step_size,
ignore_label=dataset.ignore_label,
is_training=is_training,
model_variant=model_variant)
sample = {
common.IMAGE: image,
common.IMAGE_NAME: image_name,
common.HEIGHT: height,
common.WIDTH: width
}
if label is not None:
sample[common.LABEL] = label
if not is_training:
# Original image is only used during visualization.
sample[common.ORIGINAL_IMAGE] = original_image
num_threads = 1
return tf.train.batch(
sample,
batch_size=batch_size,
num_threads=num_threads,
capacity=32 * batch_size,
allow_smaller_final_batch=not is_training,
dynamic_pad=True)
@@ -29,16 +29,20 @@ def save_annotation(label,
save_dir,
filename,
add_colormap=True,
normalize_to_unit_values=False,
scale_values=False,
colormap_type=get_dataset_colormap.get_pascal_name()):
"""Saves the given label to image on disk.
Args:
label: The numpy array to be saved. The data will be converted
to uint8 and saved as png image.
save_dir: String, the directory to which the results will be saved.
filename: String, the image filename.
add_colormap: Boolean, add color map to the label or not.
normalize_to_unit_values: Boolean, normalize the input values to [0, 1].
scale_values: Boolean, scale the input values to [0, 255] for visualization.
colormap_type: String, colormap type for visualization.
"""
# Add colormap for visualizing the prediction.
if add_colormap:
@@ -46,6 +50,15 @@ def save_annotation(label,
label, colormap_type)
else:
colored_label = label
if normalize_to_unit_values:
min_value = np.amin(colored_label)
max_value = np.amax(colored_label)
range_value = max_value - min_value
if range_value != 0:
colored_label = (colored_label - min_value) / range_value
if scale_values:
colored_label = 255. * colored_label
pil_image = img.fromarray(colored_label.astype(dtype=np.uint8))
with tf.gfile.Open('%s/%s.png' % (save_dir, filename), mode='w') as f:
...
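A hedged usage sketch of the extended `save_annotation` (the directory, file name and score map below are made up for illustration, and the output directory is assumed to exist):

```python
# Hedged sketch: save a raw per-pixel score map without a colormap by first
# normalizing it to [0, 1] and then scaling it to [0, 255].
import numpy as np
from deeplab.utils import save_annotation

score_map = np.random.rand(513, 513)  # placeholder prediction scores
save_annotation.save_annotation(
    score_map,
    save_dir='/tmp/vis',
    filename='sample_scores',
    add_colormap=False,
    normalize_to_unit_values=True,
    scale_values=True)
```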
@@ -19,7 +19,11 @@ import six
import tensorflow as tf
from deeplab.core import preprocess_utils
-slim = tf.contrib.slim
def _div_maybe_zero(total_loss, num_present):
"""Normalizes the total loss with the number of present pixels."""
return tf.to_float(num_present > 0) * tf.div(total_loss,
tf.maximum(1e-5, num_present))
def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
@@ -28,6 +32,8 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
ignore_label,
loss_weight=1.0,
upsample_logits=True,
hard_example_mining_step=0,
top_k_percent_pixels=1.0,
scope=None):
"""Adds softmax cross entropy loss for logits of each scale.
@@ -39,6 +45,15 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
ignore_label: Integer, label to ignore.
loss_weight: Float, loss weight.
upsample_logits: Boolean, upsample logits or not.
hard_example_mining_step: An integer, the training step in which the hard
example mining kicks off. Note that we gradually reduce the mining
percent to the top_k_percent_pixels. For example, if
hard_example_mining_step = 100K and top_k_percent_pixels = 0.25, then
mining percent will gradually reduce from 100% to 25% until 100K steps
after which we only mine top 25% pixels.
top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its value
< 1.0, only compute the loss for the top k percent pixels (e.g., the top
20% pixels). This is useful for hard pixel mining.
scope: String, the scope for the loss.
Raises:
@@ -69,13 +84,48 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
scaled_labels = tf.reshape(scaled_labels, shape=[-1])
not_ignore_mask = tf.to_float(tf.not_equal(scaled_labels,
ignore_label)) * loss_weight
one_hot_labels = tf.one_hot(
scaled_labels, num_classes, on_value=1.0, off_value=0.0)
if top_k_percent_pixels == 1.0:
# Compute the loss for all pixels.
tf.losses.softmax_cross_entropy(
one_hot_labels,
tf.reshape(logits, shape=[-1, num_classes]),
weights=not_ignore_mask,
scope=loss_scope)
else:
logits = tf.reshape(logits, shape=[-1, num_classes])
weights = not_ignore_mask
with tf.name_scope(loss_scope, 'softmax_hard_example_mining',
[logits, one_hot_labels, weights]):
one_hot_labels = tf.stop_gradient(
one_hot_labels, name='labels_stop_gradient')
pixel_losses = tf.nn.softmax_cross_entropy_with_logits_v2(
labels=one_hot_labels,
logits=logits,
name='pixel_losses')
weighted_pixel_losses = tf.multiply(pixel_losses, weights)
num_pixels = tf.to_float(tf.shape(logits)[0])
# Compute the top_k_percent pixels based on current training step.
if hard_example_mining_step == 0:
# Directly focus on the top_k pixels.
top_k_pixels = tf.to_int32(top_k_percent_pixels * num_pixels)
else:
# Gradually reduce the mining percent to top_k_percent_pixels.
global_step = tf.to_float(tf.train.get_or_create_global_step())
ratio = tf.minimum(1.0, global_step / hard_example_mining_step)
top_k_pixels = tf.to_int32(
(ratio * top_k_percent_pixels + (1.0 - ratio)) * num_pixels)
top_k_losses, _ = tf.nn.top_k(weighted_pixel_losses,
k=top_k_pixels,
sorted=True,
name='top_k_percent_pixels')
total_loss = tf.reduce_sum(top_k_losses)
num_present = tf.reduce_sum(
tf.to_float(tf.not_equal(top_k_losses, 0.0)))
loss = _div_maybe_zero(total_loss, num_present)
tf.losses.add_loss(loss)
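A small worked example (plain Python, not part of the library) of the mining schedule implemented above: the mined fraction starts at 100% of pixels and decays linearly to `top_k_percent_pixels` by `hard_example_mining_step`:

```python
def mined_fraction(global_step, hard_example_mining_step=100000,
                   top_k_percent_pixels=0.25):
  # Mirrors the ratio/top_k_pixels computation above, without TensorFlow.
  ratio = min(1.0, global_step / float(hard_example_mining_step))
  return ratio * top_k_percent_pixels + (1.0 - ratio)

print(mined_fraction(0))       # 1.0   -> every pixel contributes to the loss
print(mined_fraction(50000))   # 0.625 -> top 62.5% hardest pixels
print(mined_fraction(100000))  # 0.25  -> top 25% hardest pixels
```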
def get_model_init_fn(train_logdir,
@@ -110,13 +160,22 @@ def get_model_init_fn(train_logdir,
if not initialize_last_layer:
exclude_list.extend(last_layers)
variables_to_restore = tf.contrib.framework.get_variables_to_restore(
exclude=exclude_list)
if variables_to_restore:
init_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
tf_initial_checkpoint,
variables_to_restore,
ignore_missing_vars=ignore_missing_vars)
global_step = tf.train.get_or_create_global_step()
def restore_fn(unused_scaffold, sess):
sess.run(init_op, init_feed_dict)
sess.run([global_step])
return restore_fn
return None
@@ -138,7 +197,7 @@ def get_model_gradient_multipliers(last_layers, last_layer_gradient_multiplier):
"""
gradient_multipliers = {}
for var in tf.model_variables():
# Double the learning rate for biases.
if 'biases' in var.op.name:
gradient_multipliers[var.op.name] = 2.
@@ -155,10 +214,15 @@ def get_model_gradient_multipliers(last_layers, last_layer_gradient_multiplier):
return gradient_multipliers
def get_model_learning_rate(learning_policy,
base_learning_rate,
learning_rate_decay_step,
learning_rate_decay_factor,
training_number_of_steps,
learning_power,
slow_start_step,
slow_start_learning_rate,
slow_start_burnin_type='none'):
"""Gets model's learning rate.
Computes the model's learning rate for different learning policy.
@@ -181,31 +245,51 @@ def get_model_learning_rate(
slow_start_step: Training model with small learning rate for the first
few steps.
slow_start_learning_rate: The learning rate employed during slow start.
slow_start_burnin_type: The burnin type for the slow start stage. Can be
`none` which means no burnin or `linear` which means the learning rate
increases linearly from slow_start_learning_rate and reaches
base_learning_rate after slow_start_steps.
Returns:
Learning rate for the specified learning policy.
Raises:
ValueError: If learning policy or slow start burnin type is not recognized.
"""
global_step = tf.train.get_or_create_global_step()
adjusted_global_step = global_step
if slow_start_burnin_type != 'none':
adjusted_global_step -= slow_start_step
if learning_policy == 'step':
learning_rate = tf.train.exponential_decay(
base_learning_rate,
adjusted_global_step,
learning_rate_decay_step,
learning_rate_decay_factor,
staircase=True)
elif learning_policy == 'poly':
learning_rate = tf.train.polynomial_decay(
base_learning_rate,
adjusted_global_step,
training_number_of_steps,
end_learning_rate=0,
power=learning_power)
else:
raise ValueError('Unknown learning policy.')
adjusted_slow_start_learning_rate = slow_start_learning_rate
if slow_start_burnin_type == 'linear':
# Do linear burnin. Increase linearly from slow_start_learning_rate and
# reach base_learning_rate after (global_step >= slow_start_steps).
adjusted_slow_start_learning_rate = (
slow_start_learning_rate +
(base_learning_rate - slow_start_learning_rate) *
tf.to_float(global_step) / slow_start_step)
elif slow_start_burnin_type != 'none':
raise ValueError('Unknown burnin type.')
# Employ small learning rate at the first few steps for warm start.
return tf.where(global_step < slow_start_step,
adjusted_slow_start_learning_rate, learning_rate)
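A worked example (plain Python, illustrative values) of the linear burnin above: while `global_step < slow_start_step` the returned rate ramps linearly from `slow_start_learning_rate` towards `base_learning_rate`, after which the decayed schedule takes over:

```python
def burnin_learning_rate(global_step, slow_start_learning_rate=1e-4,
                         base_learning_rate=1e-2, slow_start_step=1000):
  # Mirrors the 'linear' branch of slow_start_burnin_type above.
  return (slow_start_learning_rate +
          (base_learning_rate - slow_start_learning_rate) *
          float(global_step) / slow_start_step)

print(burnin_learning_rate(0))     # 0.0001
print(burnin_learning_rate(500))   # 0.00505
print(burnin_learning_rate(1000))  # 0.01 -> hand-off to the decayed schedule
```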
@@ -17,19 +17,15 @@
See model.py for more details and usage.
"""
-import math
import os.path
import time
import numpy as np
import tensorflow as tf
from deeplab import common
from deeplab import model
from deeplab.datasets import data_generator
-from deeplab.utils import input_generator
from deeplab.utils import save_annotation
-slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
@@ -186,11 +182,24 @@ def _process_batch(sess, original_images, semantic_predictions, image_names,
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
# Get dataset-dependent information.
-dataset = segmentation_dataset.get_dataset(
-FLAGS.dataset, FLAGS.vis_split, dataset_dir=FLAGS.dataset_dir)
dataset = data_generator.Dataset(
dataset_name=FLAGS.dataset,
split_name=FLAGS.vis_split,
dataset_dir=FLAGS.dataset_dir,
batch_size=FLAGS.vis_batch_size,
crop_size=FLAGS.vis_crop_size,
min_resize_value=FLAGS.min_resize_value,
max_resize_value=FLAGS.max_resize_value,
resize_factor=FLAGS.resize_factor,
model_variant=FLAGS.model_variant,
is_training=False,
should_shuffle=False,
should_repeat=False)
train_id_to_eval_id = None
if dataset.dataset_name == data_generator.get_cityscapes_dataset_name():
tf.logging.info('Cityscapes requires converting train_id to eval_id.')
train_id_to_eval_id = _CITYSCAPES_TRAIN_ID_TO_EVAL_ID
@@ -204,20 +213,11 @@ def main(unused_argv):
tf.logging.info('Visualizing on %s set', FLAGS.vis_split)
-g = tf.Graph()
-with g.as_default():
-samples = input_generator.get(dataset,
-FLAGS.vis_crop_size,
-FLAGS.vis_batch_size,
-min_resize_value=FLAGS.min_resize_value,
-max_resize_value=FLAGS.max_resize_value,
-resize_factor=FLAGS.resize_factor,
-dataset_split=FLAGS.vis_split,
-is_training=False,
-model_variant=FLAGS.model_variant)
with tf.Graph().as_default():
samples = dataset.get_one_shot_iterator().get_next()
model_options = common.ModelOptions(
outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_of_classes},
crop_size=FLAGS.vis_crop_size,
atrous_rates=FLAGS.atrous_rates,
output_stride=FLAGS.output_stride)
@@ -244,7 +244,7 @@ def main(unused_argv):
# Reverse the resizing and padding operations performed in preprocessing.
# First, we slice the valid regions (i.e., remove padded region) and then
# we resize the predictions back.
original_image = tf.squeeze(samples[common.ORIGINAL_IMAGE])
original_image_shape = tf.shape(original_image)
predictions = tf.slice(
@@ -259,40 +259,35 @@ def main(unused_argv):
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
align_corners=True), 3)
-tf.train.get_or_create_global_step()
-saver = tf.train.Saver(slim.get_variables_to_restore())
-sv = tf.train.Supervisor(graph=g,
-logdir=FLAGS.vis_logdir,
-init_op=tf.global_variables_initializer(),
-summary_op=None,
-summary_writer=None,
-global_step=None,
-saver=saver)
-num_batches = int(math.ceil(
-dataset.num_samples / float(FLAGS.vis_batch_size)))
-last_checkpoint = None
-# Loop to visualize the results when new checkpoint is created.
-num_iters = 0
-while (FLAGS.max_number_of_iterations <= 0 or
-num_iters < FLAGS.max_number_of_iterations):
-num_iters += 1
-last_checkpoint = slim.evaluation.wait_for_new_checkpoint(
-FLAGS.checkpoint_dir, last_checkpoint)
-start = time.time()
num_iteration = 0
max_num_iteration = FLAGS.max_number_of_iterations
checkpoints_iterator = tf.contrib.training.checkpoints_iterator(
FLAGS.checkpoint_dir, min_interval_secs=FLAGS.eval_interval_secs)
for checkpoint_path in checkpoints_iterator:
if max_num_iteration > 0 and num_iteration > max_num_iteration:
break
num_iteration += 1
tf.logging.info(
'Starting visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
time.gmtime()))
-tf.logging.info('Visualizing with model %s', last_checkpoint)
-with sv.managed_session(FLAGS.master,
-start_standard_services=False) as sess:
-sv.start_queue_runners(sess)
-sv.saver.restore(sess, last_checkpoint)
tf.logging.info('Visualizing with model %s', checkpoint_path)
tf.train.get_or_create_global_step()
scaffold = tf.train.Scaffold(init_op=tf.global_variables_initializer())
session_creator = tf.train.ChiefSessionCreator(
scaffold=scaffold,
master=FLAGS.master,
checkpoint_filename_with_path=checkpoint_path)
with tf.train.MonitoredSession(
session_creator=session_creator, hooks=None) as sess:
batch = 0
image_id_offset = 0
-for batch in range(num_batches):
-tf.logging.info('Visualizing batch %d / %d', batch + 1, num_batches)
while not sess.should_stop():
tf.logging.info('Visualizing batch %d', batch + 1)
_process_batch(sess=sess,
original_images=samples[common.ORIGINAL_IMAGE],
semantic_predictions=predictions,
@@ -304,14 +299,11 @@ def main(unused_argv):
raw_save_dir=raw_save_dir,
train_id_to_eval_id=train_id_to_eval_id)
image_id_offset += FLAGS.vis_batch_size
batch += 1
tf.logging.info(
'Finished visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
time.gmtime()))
-time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
-if time_to_next_eval > 0:
-time.sleep(time_to_next_eval)
if __name__ == '__main__':
flags.mark_flag_as_required('checkpoint_dir')
...