"configs/datasets/vscode:/vscode.git/clone" did not exist on "2801883351ae5dcf2d8c5426bc450245ef04f55f"
Commit 2c962110 authored by huihui-personal, committed by aquariusjay

Open source deeplab changes. Check change log in README.md for detailed info. (#6231)

* Refactor deeplab to use MonitoredTrainingSession

PiperOrigin-RevId: 234237190

* Update export_model.py

* Update nas_cell.py

* Update nas_network.py

* Update train.py

* Update deeplab_demo.ipynb

* Update nas_cell.py
parent a432998c
@@ -66,6 +66,11 @@ A: Our model uses whole-image inference, meaning that we need to set `eval_crop_
image dimension in the dataset. For example, we have `eval_crop_size` = 513x513 for the PASCAL dataset, whose largest image dimension is 512. Similarly, we set `eval_crop_size` = 1025x2049 for Cityscapes images, whose
image dimensions are all equal to 1024x2048.
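As a quick sanity check, the crop sizes above can be reproduced with a small helper (a sketch only, assuming `eval_crop_size` takes the form `output_stride` * k + 1 with the default `output_stride` = 16; the function name is illustrative and not part of the codebase):

```python
def eval_crop_dim(max_image_dim, output_stride=16):
  """Smallest value of the form output_stride * k + 1 that covers max_image_dim."""
  k = (max_image_dim - 1 + output_stride - 1) // output_stride
  return output_stride * k + 1

print(eval_crop_dim(512))                        # 513  (PASCAL VOC)
print(eval_crop_dim(1024), eval_crop_dim(2048))  # 1025 2049 (Cityscapes)
```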
___

Q9: Why is multi-GPU training slow?

A: Please try to use more threads to pre-process the inputs. For example, change [num_readers = 4](https://github.com/tensorflow/models/blob/master/research/deeplab/utils/input_generator.py#L71) and [num_threads = 4](https://github.com/tensorflow/models/blob/master/research/deeplab/utils/input_generator.py#L72).
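If you build the input pipeline yourself with `input_generator.get` (the pipeline used by the pre-refactor `train.py`), the same knobs can also be passed as arguments. A minimal sketch, with placeholder paths and sizes:

```python
from deeplab.datasets import segmentation_dataset
from deeplab.utils import input_generator

dataset = segmentation_dataset.get_dataset(
    'pascal_voc_seg', 'train', dataset_dir='/path/to/tfrecords')
samples = input_generator.get(
    dataset,
    crop_size=[513, 513],
    batch_size=8,
    num_readers=4,   # default is 1
    num_threads=4,   # default is 1
    dataset_split='train',
    is_training=True,
    model_variant='xception_65')
```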
___
## References

...
@@ -32,13 +32,13 @@ sudo pip install matplotlib

## Add Libraries to PYTHONPATH

When running locally, the tensorflow/models/research/ directory should be
appended to PYTHONPATH. This can be done by running the following from
tensorflow/models/research/:

```bash
# From tensorflow/models/research/
export PYTHONPATH=$PYTHONPATH:`pwd`
```

Note: This command needs to run from every new terminal you start. If you wish
...
@@ -88,13 +88,18 @@ We provide some checkpoints that have been pretrained on ADE20K training set.
Note that the model has only been pretrained on ImageNet, following the
dataset rule.

Checkpoint name | Network backbone | Pretrained dataset | ASPP | Decoder | Input size
------------------------------------- | :--------------: | :-------------------------------------: | :----------------------------------------------: | :-----: | :-----:
mobilenetv2_ade20k_train | MobileNet-v2 | ImageNet <br> ADE20K training set | N/A | OS = 4 | 257x257
xception65_ade20k_train | Xception_65 | ImageNet <br> ADE20K training set | [6, 12, 18] for OS=16 <br> [12, 24, 36] for OS=8 | OS = 4 | 513x513

The input dimensions of ADE20K have a huge amount of variation. We resize inputs so that the longest size is 257 for MobileNet-v2 (faster inference) and 513 for Xception_65 (better performance). Note that we also include the decoder module in the MobileNet-v2 checkpoint.
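As a rough illustration of that resizing rule (a sketch only; the helper below is illustrative and not part of the codebase, and the actual pipeline applies the resize during preprocessing):

```python
def resize_to_longest_side(height, width, longest_side=513):
  """Scales (height, width) so that max(height, width) == longest_side."""
  scale = float(longest_side) / max(height, width)
  return int(round(height * scale)), int(round(width * scale))

print(resize_to_longest_side(300, 400, longest_side=257))  # (193, 257) for MobileNet-v2
print(resize_to_longest_side(300, 400, longest_side=513))  # (385, 513) for Xception_65
```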

Checkpoint name | Eval OS | Eval scales | Left-right Flip | mIOU | Pixel-wise Accuracy | File Size
------------------------------------- | :-------: | :-------------------------: | :-------------: | :-------------------: | :-------------------: | :-------:
[mobilenetv2_ade20k_train](http://download.tensorflow.org/models/deeplabv3_mnv2_ade20k_train_2018_12_03.tar.gz) | 16 | [1.0] | No | 32.04% (val) | 75.41% (val) | 24.8MB
[xception65_ade20k_train](http://download.tensorflow.org/models/deeplabv3_xception_ade20k_train_2018_05_29.tar.gz) | 8 | [0.5:0.25:1.75] | Yes | 45.65% (val) | 82.52% (val) | 439MB

## Checkpoints pretrained on ImageNet
...
@@ -82,7 +82,7 @@ def preprocess_image_and_label(image,
  label = tf.cast(label, tf.int32)

  # Resize image and label to the desired range.
  if min_resize_value or max_resize_value:
    [processed_image, label] = (
        preprocess_utils.resize_to_range(
            image=processed_image,
...
...@@ -56,7 +56,6 @@ from deeplab.core import dense_prediction_cell ...@@ -56,7 +56,6 @@ from deeplab.core import dense_prediction_cell
from deeplab.core import feature_extractor from deeplab.core import feature_extractor
from deeplab.core import utils from deeplab.core import utils
slim = tf.contrib.slim slim = tf.contrib.slim
LOGITS_SCOPE_NAME = 'logits' LOGITS_SCOPE_NAME = 'logits'
...@@ -67,9 +66,11 @@ CONCAT_PROJECTION_SCOPE = 'concat_projection' ...@@ -67,9 +66,11 @@ CONCAT_PROJECTION_SCOPE = 'concat_projection'
DECODER_SCOPE = 'decoder' DECODER_SCOPE = 'decoder'
META_ARCHITECTURE_SCOPE = 'meta_architecture' META_ARCHITECTURE_SCOPE = 'meta_architecture'
_resize_bilinear = utils.resize_bilinear
scale_dimension = utils.scale_dimension scale_dimension = utils.scale_dimension
split_separable_conv2d = utils.split_separable_conv2d split_separable_conv2d = utils.split_separable_conv2d
def get_extra_layer_scopes(last_layers_contain_logits_only=False): def get_extra_layer_scopes(last_layers_contain_logits_only=False):
"""Gets the scopes for extra layers. """Gets the scopes for extra layers.
...@@ -135,20 +136,20 @@ def predict_labels_multi_scale(images, ...@@ -135,20 +136,20 @@ def predict_labels_multi_scale(images,
for output in sorted(outputs_to_scales_to_logits): for output in sorted(outputs_to_scales_to_logits):
scales_to_logits = outputs_to_scales_to_logits[output] scales_to_logits = outputs_to_scales_to_logits[output]
logits = tf.image.resize_bilinear( logits = _resize_bilinear(
scales_to_logits[MERGED_LOGITS_SCOPE], scales_to_logits[MERGED_LOGITS_SCOPE],
tf.shape(images)[1:3], tf.shape(images)[1:3],
align_corners=True) scales_to_logits[MERGED_LOGITS_SCOPE].dtype)
outputs_to_predictions[output].append( outputs_to_predictions[output].append(
tf.expand_dims(tf.nn.softmax(logits), 4)) tf.expand_dims(tf.nn.softmax(logits), 4))
if add_flipped_images: if add_flipped_images:
scales_to_logits_reversed = ( scales_to_logits_reversed = (
outputs_to_scales_to_logits_reversed[output]) outputs_to_scales_to_logits_reversed[output])
logits_reversed = tf.image.resize_bilinear( logits_reversed = _resize_bilinear(
tf.reverse_v2(scales_to_logits_reversed[MERGED_LOGITS_SCOPE], [2]), tf.reverse_v2(scales_to_logits_reversed[MERGED_LOGITS_SCOPE], [2]),
tf.shape(images)[1:3], tf.shape(images)[1:3],
align_corners=True) scales_to_logits_reversed[MERGED_LOGITS_SCOPE].dtype)
outputs_to_predictions[output].append( outputs_to_predictions[output].append(
tf.expand_dims(tf.nn.softmax(logits_reversed), 4)) tf.expand_dims(tf.nn.softmax(logits_reversed), 4))
...@@ -184,37 +185,35 @@ def predict_labels(images, model_options, image_pyramid=None): ...@@ -184,37 +185,35 @@ def predict_labels(images, model_options, image_pyramid=None):
predictions = {} predictions = {}
for output in sorted(outputs_to_scales_to_logits): for output in sorted(outputs_to_scales_to_logits):
scales_to_logits = outputs_to_scales_to_logits[output] scales_to_logits = outputs_to_scales_to_logits[output]
logits = tf.image.resize_bilinear( logits = scales_to_logits[MERGED_LOGITS_SCOPE]
scales_to_logits[MERGED_LOGITS_SCOPE], # There are two ways to obtain the final prediction results: (1) bilinear
tf.shape(images)[1:3], # upsampling the logits followed by argmax, or (2) argmax followed by
align_corners=True) # nearest neighbor upsampling. The second option may introduce the "blocking
predictions[output] = tf.argmax(logits, 3) # effect" but is computationally efficient.
if model_options.prediction_with_upsampled_logits:
logits = _resize_bilinear(logits,
tf.shape(images)[1:3],
scales_to_logits[MERGED_LOGITS_SCOPE].dtype)
predictions[output] = tf.argmax(logits, 3)
else:
argmax_results = tf.argmax(logits, 3)
argmax_results = tf.image.resize_nearest_neighbor(
tf.expand_dims(argmax_results, 3),
tf.shape(images)[1:3],
align_corners=True,
name='resize_prediction')
predictions[output] = tf.squeeze(argmax_results, 3)
return predictions return predictions
def _resize_bilinear(images, size, output_dtype=tf.float32):
"""Returns resized images as output_type.
Args:
images: A tensor of size [batch, height_in, width_in, channels].
size: A 1-D int32 Tensor of 2 elements: new_height, new_width. The new size
for the images.
output_dtype: The destination type.
Returns:
A tensor of size [batch, height_out, width_out, channels] as a dtype of
output_dtype.
"""
images = tf.image.resize_bilinear(images, size, align_corners=True)
return tf.cast(images, dtype=output_dtype)
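# For reference: tf.image.resize_bilinear always returns float32, so the cast above
# lets callers preserve the input dtype, e.g.
#   resized = _resize_bilinear(logits, tf.shape(images)[1:3], logits.dtype)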
def multi_scale_logits(images, def multi_scale_logits(images,
model_options, model_options,
image_pyramid, image_pyramid,
weight_decay=0.0001, weight_decay=0.0001,
is_training=False, is_training=False,
fine_tune_batch_norm=False): fine_tune_batch_norm=False,
nas_training_hyper_parameters=None):
"""Gets the logits for multi-scale inputs. """Gets the logits for multi-scale inputs.
The returned logits are all downsampled (due to max-pooling layers) The returned logits are all downsampled (due to max-pooling layers)
...@@ -227,6 +226,12 @@ def multi_scale_logits(images, ...@@ -227,6 +226,12 @@ def multi_scale_logits(images,
weight_decay: The weight decay for model variables. weight_decay: The weight decay for model variables.
is_training: Is training or not. is_training: Is training or not.
fine_tune_batch_norm: Fine-tune the batch norm parameters or not. fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
nas_training_hyper_parameters: A dictionary storing hyper-parameters for
training nas models. Its keys are:
- `drop_path_keep_prob`: Probability to keep each path in the cell when
training.
- `total_training_steps`: Total training steps to help drop path
probability calculation.
Returns: Returns:
outputs_to_scales_to_logits: A map of maps from output_type (e.g., outputs_to_scales_to_logits: A map of maps from output_type (e.g.,
...@@ -250,10 +255,15 @@ def multi_scale_logits(images, ...@@ -250,10 +255,15 @@ def multi_scale_logits(images,
crop_width = ( crop_width = (
model_options.crop_size[1] model_options.crop_size[1]
if model_options.crop_size else tf.shape(images)[2]) if model_options.crop_size else tf.shape(images)[2])
if model_options.image_pooling_crop_size:
image_pooling_crop_height = model_options.image_pooling_crop_size[0]
image_pooling_crop_width = model_options.image_pooling_crop_size[1]
# Compute the height, width for the output logits. # Compute the height, width for the output logits.
logits_output_stride = ( if model_options.decoder_output_stride:
model_options.decoder_output_stride or model_options.output_stride) logits_output_stride = min(model_options.decoder_output_stride)
else:
logits_output_stride = model_options.output_stride
logits_height = scale_dimension( logits_height = scale_dimension(
crop_height, crop_height,
...@@ -268,33 +278,45 @@ def multi_scale_logits(images, ...@@ -268,33 +278,45 @@ def multi_scale_logits(images,
for k in model_options.outputs_to_num_classes for k in model_options.outputs_to_num_classes
} }
num_channels = images.get_shape().as_list()[-1]
for image_scale in image_pyramid: for image_scale in image_pyramid:
if image_scale != 1.0: if image_scale != 1.0:
scaled_height = scale_dimension(crop_height, image_scale) scaled_height = scale_dimension(crop_height, image_scale)
scaled_width = scale_dimension(crop_width, image_scale) scaled_width = scale_dimension(crop_width, image_scale)
scaled_crop_size = [scaled_height, scaled_width] scaled_crop_size = [scaled_height, scaled_width]
scaled_images = tf.image.resize_bilinear( scaled_images = _resize_bilinear(images, scaled_crop_size, images.dtype)
images, scaled_crop_size, align_corners=True)
if model_options.crop_size: if model_options.crop_size:
scaled_images.set_shape([None, scaled_height, scaled_width, 3]) scaled_images.set_shape(
[None, scaled_height, scaled_width, num_channels])
# Adjust image_pooling_crop_size accordingly.
scaled_image_pooling_crop_size = None
if model_options.image_pooling_crop_size:
scaled_image_pooling_crop_size = [
scale_dimension(image_pooling_crop_height, image_scale),
scale_dimension(image_pooling_crop_width, image_scale)]
else: else:
scaled_crop_size = model_options.crop_size scaled_crop_size = model_options.crop_size
scaled_images = images scaled_images = images
scaled_image_pooling_crop_size = model_options.image_pooling_crop_size
updated_options = model_options._replace(crop_size=scaled_crop_size) updated_options = model_options._replace(
crop_size=scaled_crop_size,
image_pooling_crop_size=scaled_image_pooling_crop_size)
outputs_to_logits = _get_logits( outputs_to_logits = _get_logits(
scaled_images, scaled_images,
updated_options, updated_options,
weight_decay=weight_decay, weight_decay=weight_decay,
reuse=tf.AUTO_REUSE, reuse=tf.AUTO_REUSE,
is_training=is_training, is_training=is_training,
fine_tune_batch_norm=fine_tune_batch_norm) fine_tune_batch_norm=fine_tune_batch_norm,
nas_training_hyper_parameters=nas_training_hyper_parameters)
# Resize the logits to have the same dimension before merging. # Resize the logits to have the same dimension before merging.
for output in sorted(outputs_to_logits): for output in sorted(outputs_to_logits):
outputs_to_logits[output] = tf.image.resize_bilinear( outputs_to_logits[output] = _resize_bilinear(
outputs_to_logits[output], [logits_height, logits_width], outputs_to_logits[output], [logits_height, logits_width],
align_corners=True) outputs_to_logits[output].dtype)
# Return when only one input scale. # Return when only one input scale.
if len(image_pyramid) == 1: if len(image_pyramid) == 1:
...@@ -330,7 +352,8 @@ def extract_features(images, ...@@ -330,7 +352,8 @@ def extract_features(images,
weight_decay=0.0001, weight_decay=0.0001,
reuse=None, reuse=None,
is_training=False, is_training=False,
fine_tune_batch_norm=False): fine_tune_batch_norm=False,
nas_training_hyper_parameters=None):
"""Extracts features by the particular model_variant. """Extracts features by the particular model_variant.
Args: Args:
...@@ -340,6 +363,12 @@ def extract_features(images, ...@@ -340,6 +363,12 @@ def extract_features(images,
reuse: Reuse the model variables or not. reuse: Reuse the model variables or not.
is_training: Is training or not. is_training: Is training or not.
fine_tune_batch_norm: Fine-tune the batch norm parameters or not. fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
nas_training_hyper_parameters: A dictionary storing hyper-parameters for
training nas models. Its keys are:
- `drop_path_keep_prob`: Probability to keep each path in the cell when
training.
- `total_training_steps`: Total training steps to help drop path
probability calculation.
Returns: Returns:
concat_logits: A tensor of size [batch, feature_height, feature_width, concat_logits: A tensor of size [batch, feature_height, feature_width,
...@@ -354,10 +383,16 @@ def extract_features(images, ...@@ -354,10 +383,16 @@ def extract_features(images,
multi_grid=model_options.multi_grid, multi_grid=model_options.multi_grid,
model_variant=model_options.model_variant, model_variant=model_options.model_variant,
depth_multiplier=model_options.depth_multiplier, depth_multiplier=model_options.depth_multiplier,
divisible_by=model_options.divisible_by,
weight_decay=weight_decay, weight_decay=weight_decay,
reuse=reuse, reuse=reuse,
is_training=is_training, is_training=is_training,
fine_tune_batch_norm=fine_tune_batch_norm) preprocessed_images_dtype=model_options.preprocessed_images_dtype,
fine_tune_batch_norm=fine_tune_batch_norm,
nas_stem_output_num_conv_filters=(
model_options.nas_stem_output_num_conv_filters),
nas_training_hyper_parameters=nas_training_hyper_parameters,
use_bounded_activation=model_options.use_bounded_activation)
if not model_options.aspp_with_batch_norm: if not model_options.aspp_with_batch_norm:
return features, end_points return features, end_points
...@@ -367,7 +402,7 @@ def extract_features(images, ...@@ -367,7 +402,7 @@ def extract_features(images,
dense_prediction_layer = dense_prediction_cell.DensePredictionCell( dense_prediction_layer = dense_prediction_cell.DensePredictionCell(
config=model_options.dense_prediction_cell_config, config=model_options.dense_prediction_cell_config,
hparams={ hparams={
'conv_rate_multiplier': 16 // model_options.output_stride, 'conv_rate_multiplier': 16 // model_options.output_stride,
}) })
concat_logits = dense_prediction_layer.build_cell( concat_logits = dense_prediction_layer.build_cell(
features, features,
...@@ -380,21 +415,24 @@ def extract_features(images, ...@@ -380,21 +415,24 @@ def extract_features(images,
fine_tune_batch_norm=fine_tune_batch_norm) fine_tune_batch_norm=fine_tune_batch_norm)
return concat_logits, end_points return concat_logits, end_points
else: else:
# The following codes employ the DeepLabv3 ASPP module. Note that We # The following codes employ the DeepLabv3 ASPP module. Note that we
# could express the ASPP module as one particular dense prediction # could express the ASPP module as one particular dense prediction
# cell architecture. We do not do so but leave the following codes in # cell architecture. We do not do so but leave the following codes
# order for backward compatibility. # for backward compatibility.
batch_norm_params = { batch_norm_params = {
'is_training': is_training and fine_tune_batch_norm, 'is_training': is_training and fine_tune_batch_norm,
'decay': 0.9997, 'decay': 0.9997,
'epsilon': 1e-5, 'epsilon': 1e-5,
'scale': True, 'scale': True,
} }
activation_fn = (
tf.nn.relu6 if model_options.use_bounded_activation else tf.nn.relu)
with slim.arg_scope( with slim.arg_scope(
[slim.conv2d, slim.separable_conv2d], [slim.conv2d, slim.separable_conv2d],
weights_regularizer=slim.l2_regularizer(weight_decay), weights_regularizer=slim.l2_regularizer(weight_decay),
activation_fn=tf.nn.relu, activation_fn=activation_fn,
normalizer_fn=slim.batch_norm, normalizer_fn=slim.batch_norm,
padding='SAME', padding='SAME',
stride=1, stride=1,
...@@ -416,7 +454,8 @@ def extract_features(images, ...@@ -416,7 +454,8 @@ def extract_features(images,
image_pooling_crop_size[1], image_pooling_crop_size[1],
1. / model_options.output_stride) 1. / model_options.output_stride)
image_feature = slim.avg_pool2d( image_feature = slim.avg_pool2d(
features, [pool_height, pool_width], [1, 1], padding='VALID') features, [pool_height, pool_width],
model_options.image_pooling_stride, padding='VALID')
resize_height = scale_dimension( resize_height = scale_dimension(
model_options.crop_size[0], model_options.crop_size[0],
1. / model_options.output_stride) 1. / model_options.output_stride)
...@@ -483,7 +522,8 @@ def _get_logits(images, ...@@ -483,7 +522,8 @@ def _get_logits(images,
weight_decay=0.0001, weight_decay=0.0001,
reuse=None, reuse=None,
is_training=False, is_training=False,
fine_tune_batch_norm=False): fine_tune_batch_norm=False,
nas_training_hyper_parameters=None):
"""Gets the logits by atrous/image spatial pyramid pooling. """Gets the logits by atrous/image spatial pyramid pooling.
Args: Args:
...@@ -493,6 +533,12 @@ def _get_logits(images, ...@@ -493,6 +533,12 @@ def _get_logits(images,
reuse: Reuse the model variables or not. reuse: Reuse the model variables or not.
is_training: Is training or not. is_training: Is training or not.
fine_tune_batch_norm: Fine-tune the batch norm parameters or not. fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
nas_training_hyper_parameters: A dictionary storing hyper-parameters for
training nas models. Its keys are:
- `drop_path_keep_prob`: Probability to keep each path in the cell when
training.
- `total_training_steps`: Total training steps to help drop path
probability calculation.
Returns: Returns:
outputs_to_logits: A map from output_type to logits. outputs_to_logits: A map from output_type to logits.
...@@ -503,29 +549,22 @@ def _get_logits(images, ...@@ -503,29 +549,22 @@ def _get_logits(images,
weight_decay=weight_decay, weight_decay=weight_decay,
reuse=reuse, reuse=reuse,
is_training=is_training, is_training=is_training,
fine_tune_batch_norm=fine_tune_batch_norm) fine_tune_batch_norm=fine_tune_batch_norm,
nas_training_hyper_parameters=nas_training_hyper_parameters)
if model_options.decoder_output_stride is not None: if model_options.decoder_output_stride is not None:
if model_options.crop_size is None:
height = tf.shape(images)[1]
width = tf.shape(images)[2]
else:
height, width = model_options.crop_size
decoder_height = scale_dimension(height,
1.0 / model_options.decoder_output_stride)
decoder_width = scale_dimension(width,
1.0 / model_options.decoder_output_stride)
features = refine_by_decoder( features = refine_by_decoder(
features, features,
end_points, end_points,
decoder_height=decoder_height, crop_size=model_options.crop_size,
decoder_width=decoder_width, decoder_output_stride=model_options.decoder_output_stride,
decoder_use_separable_conv=model_options.decoder_use_separable_conv, decoder_use_separable_conv=model_options.decoder_use_separable_conv,
model_variant=model_options.model_variant, model_variant=model_options.model_variant,
weight_decay=weight_decay, weight_decay=weight_decay,
reuse=reuse, reuse=reuse,
is_training=is_training, is_training=is_training,
fine_tune_batch_norm=fine_tune_batch_norm) fine_tune_batch_norm=fine_tune_batch_norm,
use_bounded_activation=model_options.use_bounded_activation)
outputs_to_logits = {} outputs_to_logits = {}
for output in sorted(model_options.outputs_to_num_classes): for output in sorted(model_options.outputs_to_num_classes):
...@@ -544,14 +583,15 @@ def _get_logits(images, ...@@ -544,14 +583,15 @@ def _get_logits(images,
def refine_by_decoder(features, def refine_by_decoder(features,
end_points, end_points,
decoder_height, crop_size=None,
decoder_width, decoder_output_stride=None,
decoder_use_separable_conv=False, decoder_use_separable_conv=False,
model_variant=None, model_variant=None,
weight_decay=0.0001, weight_decay=0.0001,
reuse=None, reuse=None,
is_training=False, is_training=False,
fine_tune_batch_norm=False): fine_tune_batch_norm=False,
use_bounded_activation=False):
"""Adds the decoder to obtain sharper segmentation results. """Adds the decoder to obtain sharper segmentation results.
Args: Args:
@@ -559,19 +599,28 @@ def refine_by_decoder(features,
      features_channels].
    end_points: A dictionary from components of the network to the corresponding
      activation.
    crop_size: A tuple [crop_height, crop_width] specifying whole patch crop
      size.
    decoder_output_stride: A list of integers specifying the output stride of
      low-level features used in the decoder module.
    decoder_use_separable_conv: Employ separable convolution for decoder or not.
    model_variant: Model variant for feature extraction.
    weight_decay: The weight decay for model variables.
    reuse: Reuse the model variables or not.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
    use_bounded_activation: Whether or not to use bounded activations. Bounded
      activations better lend themselves to quantized inference.

  Returns:
    Decoder output with size [batch, decoder_height, decoder_width,
      decoder_channels].

  Raises:
    ValueError: If crop_size is None.
  """
  if crop_size is None:
    raise ValueError('crop_size must be provided when using decoder.')
batch_norm_params = { batch_norm_params = {
'is_training': is_training and fine_tune_batch_norm, 'is_training': is_training and fine_tune_batch_norm,
'decay': 0.9997, 'decay': 0.9997,
...@@ -582,25 +631,28 @@ def refine_by_decoder(features, ...@@ -582,25 +631,28 @@ def refine_by_decoder(features,
with slim.arg_scope( with slim.arg_scope(
[slim.conv2d, slim.separable_conv2d], [slim.conv2d, slim.separable_conv2d],
weights_regularizer=slim.l2_regularizer(weight_decay), weights_regularizer=slim.l2_regularizer(weight_decay),
activation_fn=tf.nn.relu, activation_fn=tf.nn.relu6 if use_bounded_activation else tf.nn.relu,
normalizer_fn=slim.batch_norm, normalizer_fn=slim.batch_norm,
padding='SAME', padding='SAME',
stride=1, stride=1,
reuse=reuse): reuse=reuse):
with slim.arg_scope([slim.batch_norm], **batch_norm_params): with slim.arg_scope([slim.batch_norm], **batch_norm_params):
with tf.variable_scope(DECODER_SCOPE, DECODER_SCOPE, [features]): with tf.variable_scope(DECODER_SCOPE, DECODER_SCOPE, [features]):
feature_list = feature_extractor.networks_to_feature_maps[ decoder_features = features
model_variant][feature_extractor.DECODER_END_POINTS] decoder_stage = 0
if feature_list is None: scope_suffix = ''
tf.logging.info('Not found any decoder end points.') for output_stride in decoder_output_stride:
return features feature_list = feature_extractor.networks_to_feature_maps[
else: model_variant][
decoder_features = features feature_extractor.DECODER_END_POINTS][output_stride]
# If there is only one decoder stage, we do not change the scope name in
# order to preserve backward compatibility.
if decoder_stage:
scope_suffix = '_{}'.format(decoder_stage)
for i, name in enumerate(feature_list): for i, name in enumerate(feature_list):
decoder_features_list = [decoder_features] decoder_features_list = [decoder_features]
# MobileNet and NAS variants use different naming convention.
# MobileNet variants use different naming convention. if 'mobilenet' in model_variant or model_variant.startswith('nas'):
if 'mobilenet' in model_variant:
feature_name = name feature_name = name
else: else:
feature_name = '{}/{}'.format( feature_name = '{}/{}'.format(
...@@ -610,11 +662,14 @@ def refine_by_decoder(features, ...@@ -610,11 +662,14 @@ def refine_by_decoder(features,
end_points[feature_name], end_points[feature_name],
48, 48,
1, 1,
scope='feature_projection' + str(i))) scope='feature_projection' + str(i) + scope_suffix))
# Determine the output size.
decoder_height = scale_dimension(crop_size[0], 1.0 / output_stride)
decoder_width = scale_dimension(crop_size[1], 1.0 / output_stride)
# Resize to decoder_height/decoder_width. # Resize to decoder_height/decoder_width.
for j, feature in enumerate(decoder_features_list): for j, feature in enumerate(decoder_features_list):
decoder_features_list[j] = tf.image.resize_bilinear( decoder_features_list[j] = _resize_bilinear(
feature, [decoder_height, decoder_width], align_corners=True) feature, [decoder_height, decoder_width], feature.dtype)
h = (None if isinstance(decoder_height, tf.Tensor) h = (None if isinstance(decoder_height, tf.Tensor)
else decoder_height) else decoder_height)
w = (None if isinstance(decoder_width, tf.Tensor) w = (None if isinstance(decoder_width, tf.Tensor)
...@@ -627,13 +682,13 @@ def refine_by_decoder(features, ...@@ -627,13 +682,13 @@ def refine_by_decoder(features,
filters=decoder_depth, filters=decoder_depth,
rate=1, rate=1,
weight_decay=weight_decay, weight_decay=weight_decay,
scope='decoder_conv0') scope='decoder_conv0' + scope_suffix)
decoder_features = split_separable_conv2d( decoder_features = split_separable_conv2d(
decoder_features, decoder_features,
filters=decoder_depth, filters=decoder_depth,
rate=1, rate=1,
weight_decay=weight_decay, weight_decay=weight_decay,
scope='decoder_conv1') scope='decoder_conv1' + scope_suffix)
else: else:
num_convs = 2 num_convs = 2
decoder_features = slim.repeat( decoder_features = slim.repeat(
...@@ -642,8 +697,9 @@ def refine_by_decoder(features, ...@@ -642,8 +697,9 @@ def refine_by_decoder(features,
slim.conv2d, slim.conv2d,
decoder_depth, decoder_depth,
3, 3,
scope='decoder_conv' + str(i)) scope='decoder_conv' + str(i) + scope_suffix)
return decoder_features decoder_stage += 1
return decoder_features
def get_branch_logits(features, def get_branch_logits(features,
...
...@@ -87,6 +87,7 @@ class DeeplabModelTest(tf.test.TestCase): ...@@ -87,6 +87,7 @@ class DeeplabModelTest(tf.test.TestCase):
add_image_level_feature=True, add_image_level_feature=True,
aspp_with_batch_norm=True, aspp_with_batch_norm=True,
logits_kernel_size=1, logits_kernel_size=1,
decoder_output_stride=[4],
model_variant='mobilenet_v2') # Employ MobileNetv2 for fast test. model_variant='mobilenet_v2') # Employ MobileNetv2 for fast test.
g = tf.Graph() g = tf.Graph()
...@@ -116,16 +117,16 @@ class DeeplabModelTest(tf.test.TestCase): ...@@ -116,16 +117,16 @@ class DeeplabModelTest(tf.test.TestCase):
outputs_to_num_classes = {'semantic': 2} outputs_to_num_classes = {'semantic': 2}
expected_endpoints = ['merged_logits'] expected_endpoints = ['merged_logits']
dense_prediction_cell_config = [ dense_prediction_cell_config = [
{'kernel': 3, 'rate': [1, 6], 'op': 'conv', 'input': -1}, {'kernel': 3, 'rate': [1, 6], 'op': 'conv', 'input': -1},
{'kernel': 3, 'rate': [18, 15], 'op': 'conv', 'input': 0}, {'kernel': 3, 'rate': [18, 15], 'op': 'conv', 'input': 0},
] ]
model_options = common.ModelOptions( model_options = common.ModelOptions(
outputs_to_num_classes, outputs_to_num_classes,
crop_size, crop_size,
output_stride=16)._replace( output_stride=16)._replace(
aspp_with_batch_norm=True, aspp_with_batch_norm=True,
model_variant='mobilenet_v2', model_variant='mobilenet_v2',
dense_prediction_cell_config=dense_prediction_cell_config) dense_prediction_cell_config=dense_prediction_cell_config)
g = tf.Graph() g = tf.Graph()
with g.as_default(): with g.as_default():
with self.test_session(graph=g): with self.test_session(graph=g):
...@@ -137,8 +138,8 @@ class DeeplabModelTest(tf.test.TestCase): ...@@ -137,8 +138,8 @@ class DeeplabModelTest(tf.test.TestCase):
image_pyramid=[1.0]) image_pyramid=[1.0])
for output in outputs_to_num_classes: for output in outputs_to_num_classes:
scales_to_model_results = outputs_to_scales_to_model_results[output] scales_to_model_results = outputs_to_scales_to_model_results[output]
self.assertListEqual(scales_to_model_results.keys(), self.assertListEqual(
expected_endpoints) list(scales_to_model_results), expected_endpoints)
self.assertEqual(len(scales_to_model_results), 1) self.assertEqual(len(scales_to_model_results), 1)
...
This directory contains testing data.
# pascal_voc_seg
This folder contains data specific to the pascal_voc_seg dataset. val-00000-of-00001.tfrecord contains
three randomly generated images with the format defined in
tensorflow/models/research/deeplab/datasets/build_voc2012_data.py.
...@@ -19,19 +19,13 @@ See model.py for more details and usage. ...@@ -19,19 +19,13 @@ See model.py for more details and usage.
import six import six
import tensorflow as tf import tensorflow as tf
from tensorflow.python.ops import math_ops
from deeplab import common from deeplab import common
from deeplab import model from deeplab import model
from deeplab.datasets import segmentation_dataset from deeplab.datasets import data_generator
from deeplab.utils import input_generator
from deeplab.utils import train_utils from deeplab.utils import train_utils
from deployment import model_deploy
slim = tf.contrib.slim
prefetch_queue = slim.prefetch_queue
flags = tf.app.flags flags = tf.app.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
# Settings for multi-GPUs/multi-replicas training. # Settings for multi-GPUs/multi-replicas training.
...@@ -45,9 +39,10 @@ flags.DEFINE_integer('num_replicas', 1, 'Number of worker replicas.') ...@@ -45,9 +39,10 @@ flags.DEFINE_integer('num_replicas', 1, 'Number of worker replicas.')
flags.DEFINE_integer('startup_delay_steps', 15, flags.DEFINE_integer('startup_delay_steps', 15,
'Number of training steps between replicas startup.') 'Number of training steps between replicas startup.')
flags.DEFINE_integer('num_ps_tasks', 0, flags.DEFINE_integer(
'The number of parameter servers. If the value is 0, then ' 'num_ps_tasks', 0,
'the parameters are handled locally by the worker.') 'The number of parameter servers. If the value is 0, then '
'the parameters are handled locally by the worker.')
flags.DEFINE_string('master', '', 'BNS name of the tensorflow server') flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')
...@@ -67,9 +62,15 @@ flags.DEFINE_integer('save_interval_secs', 1200, ...@@ -67,9 +62,15 @@ flags.DEFINE_integer('save_interval_secs', 1200,
flags.DEFINE_integer('save_summaries_secs', 600, flags.DEFINE_integer('save_summaries_secs', 600,
'How often, in seconds, we compute the summaries.') 'How often, in seconds, we compute the summaries.')
flags.DEFINE_boolean('save_summaries_images', False, flags.DEFINE_boolean(
'Save sample inputs, labels, and semantic predictions as ' 'save_summaries_images', False,
'images to summary.') 'Save sample inputs, labels, and semantic predictions as '
'images to summary.')
# Settings for profiling.
flags.DEFINE_string('profile_logdir', None,
'Where the profile files are stored.')
# Settings for training strategy. # Settings for training strategy.
...@@ -109,13 +110,20 @@ flags.DEFINE_float('weight_decay', 0.00004, ...@@ -109,13 +110,20 @@ flags.DEFINE_float('weight_decay', 0.00004,
flags.DEFINE_multi_integer('train_crop_size', [513, 513], flags.DEFINE_multi_integer('train_crop_size', [513, 513],
'Image crop size [height, width] during training.') 'Image crop size [height, width] during training.')
flags.DEFINE_float('last_layer_gradient_multiplier', 1.0, flags.DEFINE_float(
'The gradient multiplier for last layers, which is used to ' 'last_layer_gradient_multiplier', 1.0,
'boost the gradient of last layers if the value > 1.') 'The gradient multiplier for last layers, which is used to '
'boost the gradient of last layers if the value > 1.')
flags.DEFINE_boolean('upsample_logits', True, flags.DEFINE_boolean('upsample_logits', True,
'Upsample logits during training.') 'Upsample logits during training.')
# Hyper-parameters for NAS training strategy.
flags.DEFINE_float(
'drop_path_keep_prob', 1.0,
'Probability to keep each path in the NAS cell when training.')
# Settings for fine-tuning the network. # Settings for fine-tuning the network.
flags.DEFINE_string('tf_initial_checkpoint', None, flags.DEFINE_string('tf_initial_checkpoint', None,
...@@ -157,6 +165,22 @@ flags.DEFINE_multi_integer('atrous_rates', None, ...@@ -157,6 +165,22 @@ flags.DEFINE_multi_integer('atrous_rates', None,
flags.DEFINE_integer('output_stride', 16, flags.DEFINE_integer('output_stride', 16,
'The ratio of input to output spatial resolution.') 'The ratio of input to output spatial resolution.')
# Hard example mining related flags.
flags.DEFINE_integer(
'hard_example_mining_step', 0,
'The training step in which exact hard example mining kicks off. Note we '
'gradually reduce the mining percent to the specified '
'top_k_percent_pixels. For example, if hard_example_mining_step=100K and '
'top_k_percent_pixels=0.25, then mining percent will gradually reduce from '
'100% to 25% until 100K steps after which we only mine top 25% pixels.')
flags.DEFINE_float(
'top_k_percent_pixels', 1.0,
'The top k percent pixels (in terms of the loss values) used to compute '
'loss during training. This is useful for hard pixel mining.')
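# For illustration (assuming a linear ramp between step 0 and
# hard_example_mining_step): with hard_example_mining_step=100000 and
# top_k_percent_pixels=0.25, the fraction of pixels kept for the loss would be
#   kept_fraction = 1.0 - (1.0 - 0.25) * min(step / 100000.0, 1.0)
# i.e. 100% at step 0, about 62.5% at step 50000, and 25% from step 100000 on.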
# Dataset settings. # Dataset settings.
flags.DEFINE_string('dataset', 'pascal_voc_seg', flags.DEFINE_string('dataset', 'pascal_voc_seg',
'Name of the segmentation dataset.') 'Name of the segmentation dataset.')
...@@ -167,50 +191,44 @@ flags.DEFINE_string('train_split', 'train', ...@@ -167,50 +191,44 @@ flags.DEFINE_string('train_split', 'train',
flags.DEFINE_string('dataset_dir', None, 'Where the dataset reside.') flags.DEFINE_string('dataset_dir', None, 'Where the dataset reside.')
def _build_deeplab(inputs_queue, outputs_to_num_classes, ignore_label): def _build_deeplab(iterator, outputs_to_num_classes, ignore_label):
"""Builds a clone of DeepLab. """Builds a clone of DeepLab.
Args: Args:
inputs_queue: A prefetch queue for images and labels. iterator: An iterator of type tf.data.Iterator for images and labels.
outputs_to_num_classes: A map from output type to the number of classes. outputs_to_num_classes: A map from output type to the number of classes. For
For example, for the task of semantic segmentation with 21 semantic example, for the task of semantic segmentation with 21 semantic classes,
classes, we would have outputs_to_num_classes['semantic'] = 21. we would have outputs_to_num_classes['semantic'] = 21.
ignore_label: Ignore label. ignore_label: Ignore label.
Returns:
A map of maps from output_type (e.g., semantic prediction) to a
dictionary of multi-scale logits names to logits. For each output_type,
the dictionary has keys which correspond to the scales and values which
correspond to the logits. For example, if `scales` equals [1.0, 1.5],
then the keys would include 'merged_logits', 'logits_1.00' and
'logits_1.50'.
""" """
samples = inputs_queue.dequeue() samples = iterator.get_next()
# Add name to input and label nodes so we can add to summary. # Add name to input and label nodes so we can add to summary.
samples[common.IMAGE] = tf.identity( samples[common.IMAGE] = tf.identity(samples[common.IMAGE], name=common.IMAGE)
samples[common.IMAGE], name=common.IMAGE) samples[common.LABEL] = tf.identity(samples[common.LABEL], name=common.LABEL)
samples[common.LABEL] = tf.identity(
samples[common.LABEL], name=common.LABEL)
model_options = common.ModelOptions( model_options = common.ModelOptions(
outputs_to_num_classes=outputs_to_num_classes, outputs_to_num_classes=outputs_to_num_classes,
crop_size=FLAGS.train_crop_size, crop_size=FLAGS.train_crop_size,
atrous_rates=FLAGS.atrous_rates, atrous_rates=FLAGS.atrous_rates,
output_stride=FLAGS.output_stride) output_stride=FLAGS.output_stride)
outputs_to_scales_to_logits = model.multi_scale_logits( outputs_to_scales_to_logits = model.multi_scale_logits(
samples[common.IMAGE], samples[common.IMAGE],
model_options=model_options, model_options=model_options,
image_pyramid=FLAGS.image_pyramid, image_pyramid=FLAGS.image_pyramid,
weight_decay=FLAGS.weight_decay, weight_decay=FLAGS.weight_decay,
is_training=True, is_training=True,
fine_tune_batch_norm=FLAGS.fine_tune_batch_norm) fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
nas_training_hyper_parameters={
'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
'total_training_steps': FLAGS.training_number_of_steps,
})
# Add name to graph node so we can add to summary. # Add name to graph node so we can add to summary.
output_type_dict = outputs_to_scales_to_logits[common.OUTPUT_TYPE] output_type_dict = outputs_to_scales_to_logits[common.OUTPUT_TYPE]
output_type_dict[model.MERGED_LOGITS_SCOPE] = tf.identity( output_type_dict[model.MERGED_LOGITS_SCOPE] = tf.identity(
output_type_dict[model.MERGED_LOGITS_SCOPE], output_type_dict[model.MERGED_LOGITS_SCOPE], name=common.OUTPUT_TYPE)
name=common.OUTPUT_TYPE)
for output, num_classes in six.iteritems(outputs_to_num_classes): for output, num_classes in six.iteritems(outputs_to_num_classes):
train_utils.add_softmax_cross_entropy_loss_for_each_scale( train_utils.add_softmax_cross_entropy_loss_for_each_scale(
...@@ -220,175 +238,262 @@ def _build_deeplab(inputs_queue, outputs_to_num_classes, ignore_label): ...@@ -220,175 +238,262 @@ def _build_deeplab(inputs_queue, outputs_to_num_classes, ignore_label):
ignore_label, ignore_label,
loss_weight=1.0, loss_weight=1.0,
upsample_logits=FLAGS.upsample_logits, upsample_logits=FLAGS.upsample_logits,
hard_example_mining_step=FLAGS.hard_example_mining_step,
top_k_percent_pixels=FLAGS.top_k_percent_pixels,
scope=output) scope=output)
return outputs_to_scales_to_logits # Log the summary
_log_summaries(samples[common.IMAGE], samples[common.LABEL], num_classes,
output_type_dict[model.MERGED_LOGITS_SCOPE])
def main(unused_argv): def _tower_loss(iterator, num_of_classes, ignore_label, scope, reuse_variable):
tf.logging.set_verbosity(tf.logging.INFO) """Calculates the total loss on a single tower running the deeplab model.
# Set up deployment (i.e., multi-GPUs and/or multi-replicas).
config = model_deploy.DeploymentConfig( Args:
num_clones=FLAGS.num_clones, iterator: An iterator of type tf.data.Iterator for images and labels.
clone_on_cpu=FLAGS.clone_on_cpu, num_of_classes: Number of classes for the dataset.
replica_id=FLAGS.task, ignore_label: Ignore label for the dataset.
num_replicas=FLAGS.num_replicas, scope: Unique prefix string identifying the deeplab tower.
num_ps_tasks=FLAGS.num_ps_tasks) reuse_variable: If the variable should be reused.
Returns:
The total loss for a batch of data.
"""
with tf.variable_scope(
tf.get_variable_scope(), reuse=True if reuse_variable else None):
_build_deeplab(iterator, {common.OUTPUT_TYPE: num_of_classes}, ignore_label)
# Split the batch across GPUs. losses = tf.losses.get_losses(scope=scope)
assert FLAGS.train_batch_size % config.num_clones == 0, ( for loss in losses:
'Training batch size not divisble by number of clones (GPUs).') tf.summary.scalar('Losses/%s' % loss.op.name, loss)
clone_batch_size = FLAGS.train_batch_size // config.num_clones regularization_loss = tf.losses.get_regularization_loss(scope=scope)
tf.summary.scalar('Losses/%s' % regularization_loss.op.name,
regularization_loss)
# Get dataset-dependent information. total_loss = tf.add_n([tf.add_n(losses), regularization_loss])
dataset = segmentation_dataset.get_dataset( return total_loss
FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)
def _average_gradients(tower_grads):
"""Calculates average of gradient for each shared variable across all towers.
Note that this function provides a synchronization point across all towers.
Args:
tower_grads: List of lists of (gradient, variable) tuples. The outer list is
over individual gradients. The inner list is over the gradient calculation
for each tower.
Returns:
List of pairs of (gradient, variable) where the gradient has been summed
across all towers.
"""
average_grads = []
for grad_and_vars in zip(*tower_grads):
# Note that each grad_and_vars looks like the following:
# ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
grads, variables = zip(*grad_and_vars)
grad = tf.reduce_mean(tf.stack(grads, axis=0), axis=0)
# All vars are of the same value, using the first tower here.
average_grads.append((grad, variables[0]))
return average_grads
def _log_summaries(input_image, label, num_of_classes, output):
"""Logs the summaries for the model.
Args:
input_image: Input image of the model. Its shape is [batch_size, height,
width, channel].
label: Label of the image. Its shape is [batch_size, height, width].
num_of_classes: The number of classes of the dataset.
output: Output of the model. Its shape is [batch_size, height, width].
"""
# Add summaries for model variables.
for model_var in tf.model_variables():
tf.summary.histogram(model_var.op.name, model_var)
# Add summaries for images, labels, semantic predictions.
if FLAGS.save_summaries_images:
tf.summary.image('samples/%s' % common.IMAGE, input_image)
# Scale up summary image pixel values for better visualization.
pixel_scaling = max(1, 255 // num_of_classes)
summary_label = tf.cast(label * pixel_scaling, tf.uint8)
tf.summary.image('samples/%s' % common.LABEL, summary_label)
predictions = tf.expand_dims(tf.argmax(output, 3), -1)
summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
tf.summary.image('samples/%s' % common.OUTPUT_TYPE, summary_predictions)
def _train_deeplab_model(iterator, num_of_classes, ignore_label):
"""Trains the deeplab model.
Args:
iterator: An iterator of type tf.data.Iterator for images and labels.
num_of_classes: Number of classes for the dataset.
ignore_label: Ignore label for the dataset.
Returns:
train_tensor: A tensor to update the model variables.
summary_op: An operation to log the summaries.
"""
global_step = tf.train.get_or_create_global_step()
summaries = []
learning_rate = train_utils.get_model_learning_rate(
FLAGS.learning_policy, FLAGS.base_learning_rate,
FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
FLAGS.training_number_of_steps, FLAGS.learning_power,
FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
summaries.append(tf.summary.scalar('learning_rate', learning_rate))
optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
tower_grads = []
tower_summaries = None
for i in xrange(FLAGS.num_clones):
with tf.device('/gpu:%d' % i):
with tf.name_scope('clone_%d' % i) as scope:
loss = _tower_loss(
iterator=iterator,
num_of_classes=num_of_classes,
ignore_label=ignore_label,
scope=scope,
reuse_variable=(i != 0))
grads = optimizer.compute_gradients(loss)
tower_grads.append(grads)
# Retain the summaries from the first tower.
if not i:
tower_summaries = tf.summary.merge_all(scope=scope)
with tf.device('/cpu:0'):
grads_and_vars = _average_gradients(tower_grads)
if tower_summaries is not None:
summaries.append(tower_summaries)
# Modify the gradients for biases and last layer variables.
last_layers = model.get_extra_layer_scopes(
FLAGS.last_layers_contain_logits_only)
grad_mult = train_utils.get_model_gradient_multipliers(
last_layers, FLAGS.last_layer_gradient_multiplier)
if grad_mult:
grads_and_vars = tf.contrib.training.multiply_gradients(
grads_and_vars, grad_mult)
# Create gradient update op.
grad_updates = optimizer.apply_gradients(
grads_and_vars, global_step=global_step)
# Gather update_ops. These contain, for example,
# the updates for the batch_norm variables created by model_fn.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
update_ops.append(grad_updates)
update_op = tf.group(*update_ops)
total_loss = tf.losses.get_total_loss(add_regularization_losses=True)
# Print total loss to the terminal.
# This implementation is mirrored from tf.slim.summaries.
should_log = math_ops.equal(math_ops.mod(global_step, FLAGS.log_steps), 0)
total_loss = tf.cond(
should_log,
lambda: tf.Print(total_loss, [total_loss], 'Total loss is :'),
lambda: total_loss)
summaries.append(tf.summary.scalar('total_loss', total_loss))
with tf.control_dependencies([update_op]):
train_tensor = tf.identity(total_loss, name='train_op')
summary_op = tf.summary.merge(summaries)
return train_tensor, summary_op
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
tf.gfile.MakeDirs(FLAGS.train_logdir) tf.gfile.MakeDirs(FLAGS.train_logdir)
tf.logging.info('Training on %s set', FLAGS.train_split) tf.logging.info('Training on %s set', FLAGS.train_split)
with tf.Graph().as_default() as graph: graph = tf.Graph()
with tf.device(config.inputs_device()): with graph.as_default():
samples = input_generator.get( with tf.device(tf.train.replica_device_setter(ps_tasks=FLAGS.num_ps_tasks)):
dataset, assert FLAGS.train_batch_size % FLAGS.num_clones == 0, (
FLAGS.train_crop_size, 'Training batch size not divisble by number of clones (GPUs).')
clone_batch_size, clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones
dataset = data_generator.Dataset(
dataset_name=FLAGS.dataset,
split_name=FLAGS.train_split,
dataset_dir=FLAGS.dataset_dir,
batch_size=clone_batch_size,
crop_size=FLAGS.train_crop_size,
min_resize_value=FLAGS.min_resize_value, min_resize_value=FLAGS.min_resize_value,
max_resize_value=FLAGS.max_resize_value, max_resize_value=FLAGS.max_resize_value,
resize_factor=FLAGS.resize_factor, resize_factor=FLAGS.resize_factor,
min_scale_factor=FLAGS.min_scale_factor, min_scale_factor=FLAGS.min_scale_factor,
max_scale_factor=FLAGS.max_scale_factor, max_scale_factor=FLAGS.max_scale_factor,
scale_factor_step_size=FLAGS.scale_factor_step_size, scale_factor_step_size=FLAGS.scale_factor_step_size,
dataset_split=FLAGS.train_split, model_variant=FLAGS.model_variant,
num_readers=2,
is_training=True, is_training=True,
model_variant=FLAGS.model_variant) should_shuffle=True,
inputs_queue = prefetch_queue.prefetch_queue( should_repeat=True)
samples, capacity=128 * config.num_clones)
train_tensor, summary_op = _train_deeplab_model(
# Create the global step on the device storing the variables. dataset.get_one_shot_iterator(), dataset.num_of_classes,
with tf.device(config.variables_device()): dataset.ignore_label)
global_step = tf.train.get_or_create_global_step()
# Soft placement allows placing on CPU ops without GPU implementation.
# Define the model and create clones. session_config = tf.ConfigProto(
model_fn = _build_deeplab allow_soft_placement=True, log_device_placement=False)
model_args = (inputs_queue, {
common.OUTPUT_TYPE: dataset.num_classes
}, dataset.ignore_label)
clones = model_deploy.create_clones(config, model_fn, args=model_args)
# Gather update_ops from the first clone. These contain, for example,
# the updates for the batch_norm variables created by model_fn.
first_clone_scope = config.clone_scope(0)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
# Gather initial summaries.
summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
# Add summaries for model variables.
for model_var in slim.get_model_variables():
summaries.add(tf.summary.histogram(model_var.op.name, model_var))
# Add summaries for images, labels, semantic predictions
if FLAGS.save_summaries_images:
summary_image = graph.get_tensor_by_name(
('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
summaries.add(
tf.summary.image('samples/%s' % common.IMAGE, summary_image))
first_clone_label = graph.get_tensor_by_name(
('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
# Scale up summary image pixel values for better visualization.
pixel_scaling = max(1, 255 // dataset.num_classes)
summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
summaries.add(
tf.summary.image('samples/%s' % common.LABEL, summary_label))
first_clone_output = graph.get_tensor_by_name(
('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)
summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
summaries.add(
tf.summary.image(
'samples/%s' % common.OUTPUT_TYPE, summary_predictions))
# Add summaries for losses.
for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))
# Build the optimizer based on the device specification.
with tf.device(config.optimizer_device()):
learning_rate = train_utils.get_model_learning_rate(
FLAGS.learning_policy, FLAGS.base_learning_rate,
FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
FLAGS.training_number_of_steps, FLAGS.learning_power,
FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
summaries.add(tf.summary.scalar('learning_rate', learning_rate))
startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
for variable in slim.get_model_variables():
summaries.add(tf.summary.histogram(variable.op.name, variable))
with tf.device(config.variables_device()):
total_loss, grads_and_vars = model_deploy.optimize_clones(
clones, optimizer)
total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
summaries.add(tf.summary.scalar('total_loss', total_loss))
# Modify the gradients for biases and last layer variables.
last_layers = model.get_extra_layer_scopes( last_layers = model.get_extra_layer_scopes(
FLAGS.last_layers_contain_logits_only) FLAGS.last_layers_contain_logits_only)
grad_mult = train_utils.get_model_gradient_multipliers( init_fn = None
last_layers, FLAGS.last_layer_gradient_multiplier) if FLAGS.tf_initial_checkpoint:
if grad_mult: init_fn = train_utils.get_model_init_fn(
grads_and_vars = slim.learning.multiply_gradients(
grads_and_vars, grad_mult)
# Create gradient update op.
grad_updates = optimizer.apply_gradients(
grads_and_vars, global_step=global_step)
update_ops.append(grad_updates)
update_op = tf.group(*update_ops)
with tf.control_dependencies([update_op]):
train_tensor = tf.identity(total_loss, name='train_op')
# Add the summaries from the first clone. These contain the summaries
# created by model_fn and either optimize_clones() or _gather_clone_loss().
summaries |= set(
tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
# Merge all summaries together.
summary_op = tf.summary.merge(list(summaries))
# Soft placement allows placing on CPU ops without GPU implementation.
session_config = tf.ConfigProto(
allow_soft_placement=True, log_device_placement=False)
# Start the training.
slim.learning.train(
train_tensor,
logdir=FLAGS.train_logdir,
log_every_n_steps=FLAGS.log_steps,
master=FLAGS.master,
number_of_steps=FLAGS.training_number_of_steps,
is_chief=(FLAGS.task == 0),
session_config=session_config,
startup_delay_steps=startup_delay_steps,
init_fn=train_utils.get_model_init_fn(
FLAGS.train_logdir, FLAGS.train_logdir,
FLAGS.tf_initial_checkpoint, FLAGS.tf_initial_checkpoint,
FLAGS.initialize_last_layer, FLAGS.initialize_last_layer,
last_layers, last_layers,
ignore_missing_vars=True), ignore_missing_vars=True)
summary_op=summary_op,
save_summaries_secs=FLAGS.save_summaries_secs, scaffold = tf.train.Scaffold(
save_interval_secs=FLAGS.save_interval_secs) init_fn=init_fn,
summary_op=summary_op,
)
stop_hook = tf.train.StopAtStepHook(FLAGS.training_number_of_steps)
profile_dir = FLAGS.profile_logdir
if profile_dir is not None:
tf.gfile.MakeDirs(profile_dir)
with tf.contrib.tfprof.ProfileContext(
enabled=profile_dir is not None, profile_dir=profile_dir):
with tf.train.MonitoredTrainingSession(
master=FLAGS.master,
is_chief=(FLAGS.task == 0),
config=session_config,
scaffold=scaffold,
checkpoint_dir=FLAGS.train_logdir,
log_step_count_steps=FLAGS.log_steps,
save_summaries_steps=FLAGS.save_summaries_secs,
save_checkpoint_secs=FLAGS.save_interval_secs,
hooks=[stop_hook]) as sess:
while not sess.should_stop():
sess.run([train_tensor])
if __name__ == '__main__': if __name__ == '__main__':
flags.mark_flag_as_required('train_logdir') flags.mark_flag_as_required('train_logdir')
flags.mark_flag_as_required('tf_initial_checkpoint')
flags.mark_flag_as_required('dataset_dir') flags.mark_flag_as_required('dataset_dir')
tf.app.run() tf.app.run()
@@ -37,7 +37,7 @@ _PASCAL = 'pascal'
# Max number of entries in the colormap for each dataset.
_DATASET_MAX_ENTRIES = {
    _ADE20K: 151,
    _CITYSCAPES: 256,
    _MAPILLARY_VISTAS: 66,
    _PASCAL: 256,
}
```diff
@@ -210,27 +210,27 @@ def create_cityscapes_label_colormap():
   Returns:
     A colormap for visualizing segmentation results.
   """
-  return np.asarray([
-      [128, 64, 128],
-      [244, 35, 232],
-      [70, 70, 70],
-      [102, 102, 156],
-      [190, 153, 153],
-      [153, 153, 153],
-      [250, 170, 30],
-      [220, 220, 0],
-      [107, 142, 35],
-      [152, 251, 152],
-      [70, 130, 180],
-      [220, 20, 60],
-      [255, 0, 0],
-      [0, 0, 142],
-      [0, 0, 70],
-      [0, 60, 100],
-      [0, 80, 100],
-      [0, 0, 230],
-      [119, 11, 32],
-  ])
+  colormap = np.zeros((256, 3), dtype=np.uint8)
+  colormap[0] = [128, 64, 128]
+  colormap[1] = [244, 35, 232]
+  colormap[2] = [70, 70, 70]
+  colormap[3] = [102, 102, 156]
+  colormap[4] = [190, 153, 153]
+  colormap[5] = [153, 153, 153]
+  colormap[6] = [250, 170, 30]
+  colormap[7] = [220, 220, 0]
+  colormap[8] = [107, 142, 35]
+  colormap[9] = [152, 251, 152]
+  colormap[10] = [70, 130, 180]
+  colormap[11] = [220, 20, 60]
+  colormap[12] = [255, 0, 0]
+  colormap[13] = [0, 0, 142]
+  colormap[14] = [0, 0, 70]
+  colormap[15] = [0, 60, 100]
+  colormap[16] = [0, 80, 100]
+  colormap[17] = [0, 0, 230]
+  colormap[18] = [119, 11, 32]
+  return colormap


 def create_mapillary_vistas_label_colormap():
```
```diff
@@ -396,10 +396,16 @@ def label_to_color_image(label, dataset=_PASCAL):
     map maximum entry.
   """
   if label.ndim != 2:
-    raise ValueError('Expect 2-D input label')
+    raise ValueError('Expect 2-D input label. Got {}'.format(label.shape))

   if np.max(label) >= _DATASET_MAX_ENTRIES[dataset]:
-    raise ValueError('label value too large.')
+    raise ValueError(
+        'label value too large: {} >= {}.'.format(
+            np.max(label), _DATASET_MAX_ENTRIES[dataset]))

   colormap = create_label_colormap(dataset)
   return colormap[label]
+
+
+def get_dataset_colormap_max_entries(dataset):
+  return _DATASET_MAX_ENTRIES[dataset]
```
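As a quick sanity check of the colormap utilities above, here is a hedged usage sketch. The label array and train id are made up, and it assumes the Cityscapes dataset key is the string `'cityscapes'` and that the module lives at `deeplab/utils/get_dataset_colormap.py`.

```python
import numpy as np
from deeplab.utils import get_dataset_colormap

label = np.zeros((4, 4), dtype=np.uint8)
label[:2, :2] = 7  # e.g. the Cityscapes "traffic sign" train id

# label_to_color_image rejects ids >= the dataset's colormap size.
assert np.max(label) < get_dataset_colormap.get_dataset_colormap_max_entries(
    'cityscapes')

color_image = get_dataset_colormap.label_to_color_image(label,
                                                        dataset='cityscapes')
print(color_image.shape)  # (4, 4, 3); RGB rows taken from the colormap
```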
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for providing semantic segmentation data."""
import tensorflow as tf
from deeplab import common
from deeplab import input_preprocess
slim = tf.contrib.slim
dataset_data_provider = slim.dataset_data_provider
def _get_data(data_provider, dataset_split):
"""Gets data from data provider.
Args:
data_provider: An object of slim.data_provider.
dataset_split: Dataset split.
Returns:
image: Image Tensor.
label: Label Tensor storing segmentation annotations.
image_name: Image name.
height: Image height.
width: Image width.
Raises:
ValueError: Failed to find label.
"""
if common.LABELS_CLASS not in data_provider.list_items():
raise ValueError('Failed to find labels.')
image, height, width = data_provider.get(
[common.IMAGE, common.HEIGHT, common.WIDTH])
# Some datasets do not contain image_name.
if common.IMAGE_NAME in data_provider.list_items():
image_name, = data_provider.get([common.IMAGE_NAME])
else:
image_name = tf.constant('')
label = None
if dataset_split != common.TEST_SET:
label, = data_provider.get([common.LABELS_CLASS])
return image, label, image_name, height, width
def get(dataset,
crop_size,
batch_size,
min_resize_value=None,
max_resize_value=None,
resize_factor=None,
min_scale_factor=1.,
max_scale_factor=1.,
scale_factor_step_size=0,
num_readers=1,
num_threads=1,
dataset_split=None,
is_training=True,
model_variant=None):
"""Gets the dataset split for semantic segmentation.
  This function gets the dataset split for semantic segmentation. In
  particular, it is a wrapper of (1) dataset_data_provider, which returns the
  raw dataset split, (2) input_preprocess, which preprocesses the raw data, and
  (3) the TensorFlow operation that batches the preprocessed data. The output
  can then be used directly for training, evaluation, or visualization.
Args:
dataset: An instance of slim Dataset.
crop_size: Image crop size [height, width].
batch_size: Batch size.
min_resize_value: Desired size of the smaller image side.
max_resize_value: Maximum allowed size of the larger image side.
resize_factor: Resized dimensions are multiple of factor plus one.
min_scale_factor: Minimum scale factor value.
max_scale_factor: Maximum scale factor value.
scale_factor_step_size: The step size from min scale factor to max scale
factor. The input is randomly scaled based on the value of
(min_scale_factor, max_scale_factor, scale_factor_step_size).
num_readers: Number of readers for data provider.
num_threads: Number of threads for batching data.
dataset_split: Dataset split.
is_training: Is training or not.
model_variant: Model variant (string) for choosing how to mean-subtract the
images. See feature_extractor.network_map for supported model variants.
Returns:
A dictionary of batched Tensors for semantic segmentation.
Raises:
ValueError: dataset_split is None, failed to find labels, or label shape
is not valid.
"""
if dataset_split is None:
raise ValueError('Unknown dataset split.')
if model_variant is None:
tf.logging.warning('Please specify a model_variant. See '
'feature_extractor.network_map for supported model '
'variants.')
data_provider = dataset_data_provider.DatasetDataProvider(
dataset,
num_readers=num_readers,
num_epochs=None if is_training else 1,
shuffle=is_training)
image, label, image_name, height, width = _get_data(data_provider,
dataset_split)
if label is not None:
if label.shape.ndims == 2:
label = tf.expand_dims(label, 2)
elif label.shape.ndims == 3 and label.shape.dims[2] == 1:
pass
else:
raise ValueError('Input label shape must be [height, width], or '
'[height, width, 1].')
label.set_shape([None, None, 1])
original_image, image, label = input_preprocess.preprocess_image_and_label(
image,
label,
crop_height=crop_size[0],
crop_width=crop_size[1],
min_resize_value=min_resize_value,
max_resize_value=max_resize_value,
resize_factor=resize_factor,
min_scale_factor=min_scale_factor,
max_scale_factor=max_scale_factor,
scale_factor_step_size=scale_factor_step_size,
ignore_label=dataset.ignore_label,
is_training=is_training,
model_variant=model_variant)
sample = {
common.IMAGE: image,
common.IMAGE_NAME: image_name,
common.HEIGHT: height,
common.WIDTH: width
}
if label is not None:
sample[common.LABEL] = label
if not is_training:
# Original image is only used during visualization.
sample[common.ORIGINAL_IMAGE] = original_image,
num_threads = 1
return tf.train.batch(
sample,
batch_size=batch_size,
num_threads=num_threads,
capacity=32 * batch_size,
allow_smaller_final_batch=not is_training,
dynamic_pad=True)
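To make the wrapper's pipeline (data provider, then preprocessing, then batching) concrete, here is a hedged usage sketch of the slim-based `get()` above. The dataset name, crop size, scale factors, and paths are illustrative values only, not part of this file.

```python
import tensorflow as tf
from deeplab.datasets import segmentation_dataset
from deeplab.utils import input_generator

dataset = segmentation_dataset.get_dataset(
    'pascal_voc_seg', 'train', dataset_dir='/path/to/tfrecord')  # placeholder path
samples = input_generator.get(
    dataset,
    crop_size=[513, 513],
    batch_size=8,
    min_scale_factor=0.5,
    max_scale_factor=2.0,
    scale_factor_step_size=0.25,
    num_readers=4,
    num_threads=4,
    dataset_split='train',
    is_training=True,
    model_variant='xception_65')
# samples is a dict of batched tensors keyed by common.IMAGE, common.LABEL, etc.,
# ready to be consumed by a queue-runner-based training loop.
```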
```diff
@@ -29,16 +29,20 @@ def save_annotation(label,
                     save_dir,
                     filename,
                     add_colormap=True,
+                    normalize_to_unit_values=False,
+                    scale_values=False,
                     colormap_type=get_dataset_colormap.get_pascal_name()):
   """Saves the given label to image on disk.

   Args:
     label: The numpy array to be saved. The data will be converted
       to uint8 and saved as png image.
-    save_dir: The directory to which the results will be saved.
-    filename: The image filename.
-    add_colormap: Add color map to the label or not.
-    colormap_type: Colormap type for visualization.
+    save_dir: String, the directory to which the results will be saved.
+    filename: String, the image filename.
+    add_colormap: Boolean, add color map to the label or not.
+    normalize_to_unit_values: Boolean, normalize the input values to [0, 1].
+    scale_values: Boolean, scale the input values to [0, 255] for visualization.
+    colormap_type: String, colormap type for visualization.
   """
   # Add colormap for visualizing the prediction.
   if add_colormap:
@@ -46,6 +50,15 @@ def save_annotation(label,
         label, colormap_type)
   else:
     colored_label = label
+
+  if normalize_to_unit_values:
+    min_value = np.amin(colored_label)
+    max_value = np.amax(colored_label)
+    range_value = max_value - min_value
+    if range_value != 0:
+      colored_label = (colored_label - min_value) / range_value
+
+  if scale_values:
+    colored_label = 255. * colored_label
+
   pil_image = img.fromarray(colored_label.astype(dtype=np.uint8))
   with tf.gfile.Open('%s/%s.png' % (save_dir, filename), mode='w') as f:
```
```diff
@@ -19,7 +19,11 @@ import six
 import tensorflow as tf

 from deeplab.core import preprocess_utils

-slim = tf.contrib.slim
+
+def _div_maybe_zero(total_loss, num_present):
+  """Normalizes the total loss with the number of present pixels."""
+  return tf.to_float(num_present > 0) * tf.div(total_loss,
+                                               tf.maximum(1e-5, num_present))


 def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
```
```diff
@@ -28,6 +32,8 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
                                                   ignore_label,
                                                   loss_weight=1.0,
                                                   upsample_logits=True,
+                                                  hard_example_mining_step=0,
+                                                  top_k_percent_pixels=1.0,
                                                   scope=None):
   """Adds softmax cross entropy loss for logits of each scale.
```
```diff
@@ -39,6 +45,15 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
     ignore_label: Integer, label to ignore.
     loss_weight: Float, loss weight.
     upsample_logits: Boolean, upsample logits or not.
+    hard_example_mining_step: An integer, the training step in which the hard
+      example mining kicks off. Note that we gradually reduce the mining
+      percent to the top_k_percent_pixels. For example, if
+      hard_example_mining_step = 100K and top_k_percent_pixels = 0.25, then
+      mining percent will gradually reduce from 100% to 25% until 100K steps
+      after which we only mine top 25% pixels.
+    top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its value
+      < 1.0, only compute the loss for the top k percent pixels (e.g., the top
+      20% pixels). This is useful for hard pixel mining.
     scope: String, the scope for the loss.

   Raises:
```
```diff
@@ -69,13 +84,48 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
       scaled_labels = tf.reshape(scaled_labels, shape=[-1])
       not_ignore_mask = tf.to_float(tf.not_equal(scaled_labels,
                                                  ignore_label)) * loss_weight
-      one_hot_labels = slim.one_hot_encoding(
+      one_hot_labels = tf.one_hot(
           scaled_labels, num_classes, on_value=1.0, off_value=0.0)
-      tf.losses.softmax_cross_entropy(
-          one_hot_labels,
-          tf.reshape(logits, shape=[-1, num_classes]),
-          weights=not_ignore_mask,
-          scope=loss_scope)
+
+      if top_k_percent_pixels == 1.0:
+        # Compute the loss for all pixels.
+        tf.losses.softmax_cross_entropy(
+            one_hot_labels,
+            tf.reshape(logits, shape=[-1, num_classes]),
+            weights=not_ignore_mask,
+            scope=loss_scope)
+      else:
+        logits = tf.reshape(logits, shape=[-1, num_classes])
+        weights = not_ignore_mask
+        with tf.name_scope(loss_scope, 'softmax_hard_example_mining',
+                           [logits, one_hot_labels, weights]):
+          one_hot_labels = tf.stop_gradient(
+              one_hot_labels, name='labels_stop_gradient')
+          pixel_losses = tf.nn.softmax_cross_entropy_with_logits_v2(
+              labels=one_hot_labels,
+              logits=logits,
+              name='pixel_losses')
+          weighted_pixel_losses = tf.multiply(pixel_losses, weights)
+          num_pixels = tf.to_float(tf.shape(logits)[0])
+          # Compute the top_k_percent pixels based on current training step.
+          if hard_example_mining_step == 0:
+            # Directly focus on the top_k pixels.
+            top_k_pixels = tf.to_int32(top_k_percent_pixels * num_pixels)
+          else:
+            # Gradually reduce the mining percent to top_k_percent_pixels.
+            global_step = tf.to_float(tf.train.get_or_create_global_step())
+            ratio = tf.minimum(1.0, global_step / hard_example_mining_step)
+            top_k_pixels = tf.to_int32(
+                (ratio * top_k_percent_pixels + (1.0 - ratio)) * num_pixels)
+          top_k_losses, _ = tf.nn.top_k(weighted_pixel_losses,
+                                        k=top_k_pixels,
+                                        sorted=True,
+                                        name='top_k_percent_pixels')
+          total_loss = tf.reduce_sum(top_k_losses)
+          num_present = tf.reduce_sum(
+              tf.to_float(tf.not_equal(top_k_losses, 0.0)))
+          loss = _div_maybe_zero(total_loss, num_present)
+          tf.losses.add_loss(loss)
```
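The annealing schedule implemented above is easy to check with plain Python. The hyperparameters below are the ones used as an example in the docstring (100K-step ramp, top 25% of pixels) plus an assumed 513x513 crop; none of them are fixed by the code.

```python
top_k_percent_pixels = 0.25
hard_example_mining_step = 100000
num_pixels = 513 * 513  # pixels per crop, assuming a 513x513 input

for step in (0, 25000, 50000, 100000, 200000):
  ratio = min(1.0, float(step) / hard_example_mining_step)
  keep_fraction = ratio * top_k_percent_pixels + (1.0 - ratio)
  print(step, keep_fraction, int(keep_fraction * num_pixels))
# step 0       -> keep 100% of the pixel losses
# step 50,000  -> keep the hardest 62.5%
# step 100,000 and beyond -> keep only the hardest 25%
```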
```diff
@@ -110,13 +160,22 @@ def get_model_init_fn(train_logdir,
   if not initialize_last_layer:
     exclude_list.extend(last_layers)

-  variables_to_restore = slim.get_variables_to_restore(exclude=exclude_list)
+  variables_to_restore = tf.contrib.framework.get_variables_to_restore(
+      exclude=exclude_list)

   if variables_to_restore:
-    return slim.assign_from_checkpoint_fn(
+    init_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
         tf_initial_checkpoint,
         variables_to_restore,
         ignore_missing_vars=ignore_missing_vars)
+    global_step = tf.train.get_or_create_global_step()
+
+    def restore_fn(unused_scaffold, sess):
+      sess.run(init_op, init_feed_dict)
+      sess.run([global_step])
+
+    return restore_fn
+
   return None
```
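The returned `restore_fn(unused_scaffold, sess)` matches the signature that `tf.train.Scaffold` expects for its `init_fn`, which is how train.py wires it into `MonitoredTrainingSession`. A hedged sketch with placeholder paths and layer names:

```python
import tensorflow as tf
from deeplab.utils import train_utils

init_fn = train_utils.get_model_init_fn(
    '/tmp/train_logdir',        # placeholder train_logdir
    '/path/to/model.ckpt',      # placeholder tf_initial_checkpoint
    initialize_last_layer=False,
    last_layers=['logits'],     # placeholder last-layer scope names
    ignore_missing_vars=True)

scaffold = tf.train.Scaffold(init_fn=init_fn)
# MonitoredTrainingSession calls init_fn(scaffold, session) once, after the
# variable initializers run, when no checkpoint exists in checkpoint_dir yet.
```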
```diff
@@ -138,7 +197,7 @@ def get_model_gradient_multipliers(last_layers, last_layer_gradient_multiplier):
   """
   gradient_multipliers = {}

-  for var in slim.get_model_variables():
+  for var in tf.model_variables():
     # Double the learning rate for biases.
     if 'biases' in var.op.name:
       gradient_multipliers[var.op.name] = 2.
```
```diff
@@ -155,10 +214,15 @@ def get_model_gradient_multipliers(last_layers, last_layer_gradient_multiplier):
   return gradient_multipliers


-def get_model_learning_rate(
-    learning_policy, base_learning_rate, learning_rate_decay_step,
-    learning_rate_decay_factor, training_number_of_steps, learning_power,
-    slow_start_step, slow_start_learning_rate):
+def get_model_learning_rate(learning_policy,
+                            base_learning_rate,
+                            learning_rate_decay_step,
+                            learning_rate_decay_factor,
+                            training_number_of_steps,
+                            learning_power,
+                            slow_start_step,
+                            slow_start_learning_rate,
+                            slow_start_burnin_type='none'):
   """Gets model's learning rate.

   Computes the model's learning rate for different learning policy.
```
```diff
@@ -181,31 +245,51 @@ def get_model_learning_rate(
     slow_start_step: Training model with small learning rate for the first
       few steps.
     slow_start_learning_rate: The learning rate employed during slow start.
+    slow_start_burnin_type: The burnin type for the slow start stage. Can be
+      `none` which means no burnin or `linear` which means the learning rate
+      increases linearly from slow_start_learning_rate and reaches
+      base_learning_rate after slow_start_steps.

   Returns:
     Learning rate for the specified learning policy.

   Raises:
-    ValueError: If learning policy is not recognized.
+    ValueError: If learning policy or slow start burnin type is not recognized.
   """
   global_step = tf.train.get_or_create_global_step()
+  adjusted_global_step = global_step
+
+  if slow_start_burnin_type != 'none':
+    adjusted_global_step -= slow_start_step
+
   if learning_policy == 'step':
     learning_rate = tf.train.exponential_decay(
         base_learning_rate,
-        global_step,
+        adjusted_global_step,
         learning_rate_decay_step,
         learning_rate_decay_factor,
         staircase=True)
   elif learning_policy == 'poly':
     learning_rate = tf.train.polynomial_decay(
         base_learning_rate,
-        global_step,
+        adjusted_global_step,
         training_number_of_steps,
         end_learning_rate=0,
         power=learning_power)
   else:
     raise ValueError('Unknown learning policy.')

+  adjusted_slow_start_learning_rate = slow_start_learning_rate
+  if slow_start_burnin_type == 'linear':
+    # Do linear burnin. Increase linearly from slow_start_learning_rate and
+    # reach base_learning_rate after (global_step >= slow_start_steps).
+    adjusted_slow_start_learning_rate = (
+        slow_start_learning_rate +
+        (base_learning_rate - slow_start_learning_rate) *
+        tf.to_float(global_step) / slow_start_step)
+  elif slow_start_burnin_type != 'none':
+    raise ValueError('Unknown burnin type.')
+
   # Employ small learning rate at the first few steps for warm start.
-  return tf.where(global_step < slow_start_step, slow_start_learning_rate,
-                  learning_rate)
+  return tf.where(global_step < slow_start_step,
+                  adjusted_slow_start_learning_rate, learning_rate)
```
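A plain-Python sketch of the resulting schedule for the common `poly` policy with a linear burnin. All hyperparameter values below are assumptions for illustration, and the poly branch mirrors `tf.train.polynomial_decay` with `end_learning_rate=0`.

```python
base_learning_rate = 1e-2          # assumed
slow_start_learning_rate = 1e-4    # assumed
slow_start_step = 1000             # assumed
training_number_of_steps = 30000   # assumed
learning_power = 0.9

def learning_rate_at(step):
  if step < slow_start_step:
    # Linear burnin from slow_start_learning_rate up to base_learning_rate.
    return (slow_start_learning_rate +
            (base_learning_rate - slow_start_learning_rate) *
            float(step) / slow_start_step)
  # 'poly' decay on the burnin-adjusted step.
  adjusted_step = min(step - slow_start_step, training_number_of_steps)
  return base_learning_rate * (
      1.0 - float(adjusted_step) / training_number_of_steps) ** learning_power

for step in (0, 500, 1000, 15000, 30000):
  print(step, learning_rate_at(step))
```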
```diff
@@ -17,19 +17,15 @@
 See model.py for more details and usage.
 """
-import math
 import os.path
 import time
 import numpy as np
 import tensorflow as tf

 from deeplab import common
 from deeplab import model
-from deeplab.datasets import segmentation_dataset
-from deeplab.utils import input_generator
+from deeplab.datasets import data_generator
 from deeplab.utils import save_annotation

-slim = tf.contrib.slim

 flags = tf.app.flags

 FLAGS = flags.FLAGS
```
```diff
@@ -186,11 +182,24 @@ def _process_batch(sess, original_images, semantic_predictions, image_names,
 def main(unused_argv):
   tf.logging.set_verbosity(tf.logging.INFO)

   # Get dataset-dependent information.
-  dataset = segmentation_dataset.get_dataset(
-      FLAGS.dataset, FLAGS.vis_split, dataset_dir=FLAGS.dataset_dir)
+  dataset = data_generator.Dataset(
+      dataset_name=FLAGS.dataset,
+      split_name=FLAGS.vis_split,
+      dataset_dir=FLAGS.dataset_dir,
+      batch_size=FLAGS.vis_batch_size,
+      crop_size=FLAGS.vis_crop_size,
+      min_resize_value=FLAGS.min_resize_value,
+      max_resize_value=FLAGS.max_resize_value,
+      resize_factor=FLAGS.resize_factor,
+      model_variant=FLAGS.model_variant,
+      is_training=False,
+      should_shuffle=False,
+      should_repeat=False)

   train_id_to_eval_id = None
-  if dataset.name == segmentation_dataset.get_cityscapes_dataset_name():
+  if dataset.dataset_name == data_generator.get_cityscapes_dataset_name():
     tf.logging.info('Cityscapes requires converting train_id to eval_id.')
     train_id_to_eval_id = _CITYSCAPES_TRAIN_ID_TO_EVAL_ID
```
```diff
@@ -204,20 +213,11 @@ def main(unused_argv):
   tf.logging.info('Visualizing on %s set', FLAGS.vis_split)

-  g = tf.Graph()
-  with g.as_default():
-    samples = input_generator.get(dataset,
-                                  FLAGS.vis_crop_size,
-                                  FLAGS.vis_batch_size,
-                                  min_resize_value=FLAGS.min_resize_value,
-                                  max_resize_value=FLAGS.max_resize_value,
-                                  resize_factor=FLAGS.resize_factor,
-                                  dataset_split=FLAGS.vis_split,
-                                  is_training=False,
-                                  model_variant=FLAGS.model_variant)
+  with tf.Graph().as_default():
+    samples = dataset.get_one_shot_iterator().get_next()

     model_options = common.ModelOptions(
-        outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_classes},
+        outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_of_classes},
         crop_size=FLAGS.vis_crop_size,
         atrous_rates=FLAGS.atrous_rates,
         output_stride=FLAGS.output_stride)
```
```diff
@@ -244,7 +244,7 @@ def main(unused_argv):
     # Reverse the resizing and padding operations performed in preprocessing.
     # First, we slice the valid regions (i.e., remove padded region) and then
-    # we reisze the predictions back.
+    # we resize the predictions back.
     original_image = tf.squeeze(samples[common.ORIGINAL_IMAGE])
     original_image_shape = tf.shape(original_image)
     predictions = tf.slice(
```
```diff
@@ -259,40 +259,35 @@ def main(unused_argv):
             method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
             align_corners=True), 3)

-    tf.train.get_or_create_global_step()
-    saver = tf.train.Saver(slim.get_variables_to_restore())
-    sv = tf.train.Supervisor(graph=g,
-                             logdir=FLAGS.vis_logdir,
-                             init_op=tf.global_variables_initializer(),
-                             summary_op=None,
-                             summary_writer=None,
-                             global_step=None,
-                             saver=saver)
-    num_batches = int(math.ceil(
-        dataset.num_samples / float(FLAGS.vis_batch_size)))
-    last_checkpoint = None
-
-    # Loop to visualize the results when new checkpoint is created.
-    num_iters = 0
-    while (FLAGS.max_number_of_iterations <= 0 or
-           num_iters < FLAGS.max_number_of_iterations):
-      num_iters += 1
-      last_checkpoint = slim.evaluation.wait_for_new_checkpoint(
-          FLAGS.checkpoint_dir, last_checkpoint)
-      start = time.time()
+    num_iteration = 0
+    max_num_iteration = FLAGS.max_number_of_iterations
+
+    checkpoints_iterator = tf.contrib.training.checkpoints_iterator(
+        FLAGS.checkpoint_dir, min_interval_secs=FLAGS.eval_interval_secs)
+    for checkpoint_path in checkpoints_iterator:
+      if max_num_iteration > 0 and num_iteration > max_num_iteration:
+        break
+      num_iteration += 1
       tf.logging.info(
           'Starting visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                        time.gmtime()))
-      tf.logging.info('Visualizing with model %s', last_checkpoint)
+      tf.logging.info('Visualizing with model %s', checkpoint_path)

-      with sv.managed_session(FLAGS.master,
-                              start_standard_services=False) as sess:
-        sv.start_queue_runners(sess)
-        sv.saver.restore(sess, last_checkpoint)
+      tf.train.get_or_create_global_step()
+
+      scaffold = tf.train.Scaffold(init_op=tf.global_variables_initializer())
+      session_creator = tf.train.ChiefSessionCreator(
+          scaffold=scaffold,
+          master=FLAGS.master,
+          checkpoint_filename_with_path=checkpoint_path)
+      with tf.train.MonitoredSession(
+          session_creator=session_creator, hooks=None) as sess:
+        batch = 0
         image_id_offset = 0
-        for batch in range(num_batches):
-          tf.logging.info('Visualizing batch %d / %d', batch + 1, num_batches)
+
+        while not sess.should_stop():
+          tf.logging.info('Visualizing batch %d', batch + 1)
           _process_batch(sess=sess,
                          original_images=samples[common.ORIGINAL_IMAGE],
                          semantic_predictions=predictions,
@@ -304,14 +299,11 @@ def main(unused_argv):
                          raw_save_dir=raw_save_dir,
                          train_id_to_eval_id=train_id_to_eval_id)
           image_id_offset += FLAGS.vis_batch_size
+          batch += 1

       tf.logging.info(
           'Finished visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                        time.gmtime()))
-      time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
-      if time_to_next_eval > 0:
-        time.sleep(time_to_next_eval)
```
if __name__ == '__main__':
  flags.mark_flag_as_required('checkpoint_dir')
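The removed code above polled for checkpoints by hand and slept between rounds; the new loop relies on `tf.contrib.training.checkpoints_iterator`, which blocks until a new checkpoint appears. A hedged standalone sketch, with placeholder directory, interval, and timeout values:

```python
import tensorflow as tf

for checkpoint_path in tf.contrib.training.checkpoints_iterator(
    '/tmp/train_logdir',    # placeholder checkpoint directory
    min_interval_secs=300,  # poll at most once every five minutes
    timeout=3600):          # stop iterating after an hour with no new checkpoint
  tf.logging.info('Found new checkpoint: %s', checkpoint_path)
```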