Commit 2c962110 authored by huihui-personal's avatar huihui-personal Committed by aquariusjay

Open source deeplab changes. Check change log in README.md for detailed info. (#6231)

* Refactor deeplab to use MonitoredTrainingSession

PiperOrigin-RevId: 234237190

* Update export_model.py

* Update nas_cell.py

* Update nas_network.py

* Update train.py

* Update deeplab_demo.ipynb

* Update nas_cell.py
parent a432998c
......@@ -66,6 +66,11 @@ A: Our model uses whole-image inference, meaning that we need to set `eval_crop_
image dimension in the dataset. For example, we have `eval_crop_size` = 513x513 for the PASCAL dataset, whose largest image dimension is 512. Similarly, we set `eval_crop_size` = 1025x2049 for Cityscapes images, whose
dimensions are all equal to 1024x2048.
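For reference, a quick way to compute a valid `eval_crop_size` (a plain-Python sketch, assuming the usual deeplab convention that crop dimensions have the form k * output_stride + 1, e.g. with output_stride = 16):
```python
import math

def min_eval_crop_dim(largest_dim, output_stride=16):
  # Smallest value >= largest_dim of the form k * output_stride + 1.
  return int(math.ceil((largest_dim - 1) / float(output_stride))) * output_stride + 1

print(min_eval_crop_dim(512))                            # 513 (PASCAL)
print(min_eval_crop_dim(1024), min_eval_crop_dim(2048))  # 1025 2049 (Cityscapes)
```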
___
Q9: Why is multi-GPU training slow?
A: Please try using more threads to pre-process the inputs. For example, change [num_readers = 4](https://github.com/tensorflow/models/blob/master/research/deeplab/utils/input_generator.py#L71) and [num_threads = 4](https://github.com/tensorflow/models/blob/master/research/deeplab/utils/input_generator.py#L72).
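For instance, a sketch of such a call to `input_generator.get` (the dataset path is hypothetical and the other values are illustrative):
```python
from deeplab.datasets import segmentation_dataset
from deeplab.utils import input_generator

# Hypothetical tfrecord location.
dataset = segmentation_dataset.get_dataset(
    'pascal_voc_seg', 'train', dataset_dir='/tmp/pascal/tfrecord')
samples = input_generator.get(
    dataset,
    crop_size=[513, 513],
    batch_size=8,
    num_readers=4,  # more parallel tfrecord readers
    num_threads=4,  # more preprocessing/batching threads
    dataset_split='train',
    is_training=True,
    model_variant='xception_65')
```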
___
## References
......
......@@ -32,13 +32,13 @@ sudo pip install matplotlib
## Add Libraries to PYTHONPATH
-When running locally, the tensorflow/models/research/ and slim directories
-should be appended to PYTHONPATH. This can be done by running the following from
+When running locally, the tensorflow/models/research/ directory should be
+appended to PYTHONPATH. This can be done by running the following from
tensorflow/models/research/:
```bash
# From tensorflow/models/research/
-export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
+export PYTHONPATH=$PYTHONPATH:`pwd`
```
Note: This command needs to run from every new terminal you start. If you wish
......
......@@ -88,13 +88,18 @@ We provide some checkpoints that have been pretrained on ADE20K training set.
Note that the model has only been pretrained on ImageNet, following the
dataset rule.
-Checkpoint name | Network backbone | Pretrained dataset | ASPP | Decoder
-------------------------------------- | :--------------: | :-------------------------------------: | :----------------------------------------------: | :-----:
-xception65_ade20k_train | Xception_65 | ImageNet <br> ADE20K training set | [6, 12, 18] for OS=16 <br> [12, 24, 36] for OS=8 | OS = 4
+Checkpoint name | Network backbone | Pretrained dataset | ASPP | Decoder | Input size
+------------------------------------- | :--------------: | :-------------------------------------: | :----------------------------------------------: | :-----: | :-----:
+mobilenetv2_ade20k_train | MobileNet-v2 | ImageNet <br> ADE20K training set | N/A | OS = 4 | 257x257
+xception65_ade20k_train | Xception_65 | ImageNet <br> ADE20K training set | [6, 12, 18] for OS=16 <br> [12, 24, 36] for OS=8 | OS = 4 | 513x513
+
+Image dimensions in ADE20K vary widely. We resize inputs so that the longest side is 257 for MobileNet-v2 (faster inference) and 513 for Xception_65 (better performance). Note that we also include the decoder module in the MobileNet-v2 checkpoint.
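As a rough sketch of this resizing rule (plain Python; the actual pipeline uses `preprocess_utils.resize_to_range`):
```python
def target_size(height, width, longest_side):
  # Scale so the longer side equals longest_side, keeping the aspect ratio.
  scale = float(longest_side) / max(height, width)
  return int(round(height * scale)), int(round(width * scale))

print(target_size(450, 600, 513))  # Xception_65 input: (385, 513)
print(target_size(450, 600, 257))  # MobileNet-v2 input: (193, 257)
```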
Checkpoint name | Eval OS | Eval scales | Left-right Flip | mIOU | Pixel-wise Accuracy | File Size
------------------------------------- | :-------: | :-------------------------: | :-------------: | :-------------------: | :-------------------: | :-------:
-[xception65_ade20k_train](http://download.tensorflow.org/models/deeplabv3_xception_ade20k_train_2018_05_29.tar.gz) | 8 | [0.5:0.25:1.75] | Yes | 45.65% (val) | 82.52% (val) | 439MB
+[mobilenetv2_ade20k_train](http://download.tensorflow.org/models/deeplabv3_mnv2_ade20k_train_2018_12_03.tar.gz) | 16 | [1.0] | No | 32.04% (val) | 75.41% (val) | 24.8MB
+[xception65_ade20k_train](http://download.tensorflow.org/models/deeplabv3_xception_ade20k_train_2018_05_29.tar.gz) | 8 | [0.5:0.25:1.75] | Yes | 45.65% (val) | 82.52% (val) | 439MB
## Checkpoints pretrained on ImageNet
......
......@@ -82,7 +82,7 @@ def preprocess_image_and_label(image,
label = tf.cast(label, tf.int32)
# Resize image and label to the desired range.
-if min_resize_value is not None or max_resize_value is not None:
+if min_resize_value or max_resize_value:
[processed_image, label] = (
preprocess_utils.resize_to_range(
image=processed_image,
......
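Note the semantic difference in the condition above: `None` and `0` are both falsy in Python, so a value of 0 now also disables resizing. A tiny illustration:
```python
for value in (None, 0, 513):
  print(value, 'enables resize' if value else 'disables resize')
# None and 0 both skip the resize branch; 513 triggers it.
```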
This diff is collapsed.
......@@ -87,6 +87,7 @@ class DeeplabModelTest(tf.test.TestCase):
add_image_level_feature=True,
aspp_with_batch_norm=True,
logits_kernel_size=1,
+decoder_output_stride=[4],
model_variant='mobilenet_v2') # Employ MobileNetv2 for fast test.
g = tf.Graph()
......@@ -116,16 +117,16 @@ class DeeplabModelTest(tf.test.TestCase):
outputs_to_num_classes = {'semantic': 2}
expected_endpoints = ['merged_logits']
dense_prediction_cell_config = [
{'kernel': 3, 'rate': [1, 6], 'op': 'conv', 'input': -1},
{'kernel': 3, 'rate': [18, 15], 'op': 'conv', 'input': 0},
]
model_options = common.ModelOptions(
outputs_to_num_classes,
crop_size,
output_stride=16)._replace(
aspp_with_batch_norm=True,
model_variant='mobilenet_v2',
dense_prediction_cell_config=dense_prediction_cell_config)
g = tf.Graph()
with g.as_default():
with self.test_session(graph=g):
......@@ -137,8 +138,8 @@ class DeeplabModelTest(tf.test.TestCase):
image_pyramid=[1.0])
for output in outputs_to_num_classes:
scales_to_model_results = outputs_to_scales_to_model_results[output]
-self.assertListEqual(scales_to_model_results.keys(),
-                     expected_endpoints)
+self.assertListEqual(
+    list(scales_to_model_results), expected_endpoints)
self.assertEqual(len(scales_to_model_results), 1)
......
This directory contains testing data.
# pascal_voc_seg
This folder contains data specific to the pascal_voc_seg dataset. val-00000-of-00001.tfrecord contains
three randomly generated images in the format defined in
tensorflow/models/research/deeplab/datasets/build_voc2012_data.py.
This diff is collapsed.
......@@ -37,7 +37,7 @@ _PASCAL = 'pascal'
# Max number of entries in the colormap for each dataset.
_DATASET_MAX_ENTRIES = {
_ADE20K: 151,
-_CITYSCAPES: 19,
+_CITYSCAPES: 256,
_MAPILLARY_VISTAS: 66,
_PASCAL: 256,
}
......@@ -210,27 +210,27 @@ def create_cityscapes_label_colormap():
Returns:
A colormap for visualizing segmentation results.
"""
-  return np.asarray([
-      [128, 64, 128],
-      [244, 35, 232],
-      [70, 70, 70],
-      [102, 102, 156],
-      [190, 153, 153],
-      [153, 153, 153],
-      [250, 170, 30],
-      [220, 220, 0],
-      [107, 142, 35],
-      [152, 251, 152],
-      [70, 130, 180],
-      [220, 20, 60],
-      [255, 0, 0],
-      [0, 0, 142],
-      [0, 0, 70],
-      [0, 60, 100],
-      [0, 80, 100],
-      [0, 0, 230],
-      [119, 11, 32],
-  ])
+  colormap = np.zeros((256, 3), dtype=np.uint8)
+  colormap[0] = [128, 64, 128]
+  colormap[1] = [244, 35, 232]
+  colormap[2] = [70, 70, 70]
+  colormap[3] = [102, 102, 156]
+  colormap[4] = [190, 153, 153]
+  colormap[5] = [153, 153, 153]
+  colormap[6] = [250, 170, 30]
+  colormap[7] = [220, 220, 0]
+  colormap[8] = [107, 142, 35]
+  colormap[9] = [152, 251, 152]
+  colormap[10] = [70, 130, 180]
+  colormap[11] = [220, 20, 60]
+  colormap[12] = [255, 0, 0]
+  colormap[13] = [0, 0, 142]
+  colormap[14] = [0, 0, 70]
+  colormap[15] = [0, 60, 100]
+  colormap[16] = [0, 80, 100]
+  colormap[17] = [0, 0, 230]
+  colormap[18] = [119, 11, 32]
+  return colormap
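With 256 entries, out-of-range train IDs such as the common ignore label 255 now index safely into the colormap (coming out black) instead of failing the bounds check, which is also why `_DATASET_MAX_ENTRIES[_CITYSCAPES]` becomes 256. A minimal sketch of the lookup:
```python
import numpy as np

colormap = np.zeros((256, 3), dtype=np.uint8)
colormap[0] = [128, 64, 128]  # road; the remaining train IDs fill rows 1-18

label = np.array([[0, 255],
                  [255, 0]], dtype=np.uint8)  # 255 = ignore label
print(colormap[label].shape)  # (2, 2, 3); ignore pixels map to black
```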
def create_mapillary_vistas_label_colormap():
......@@ -396,10 +396,16 @@ def label_to_color_image(label, dataset=_PASCAL):
map maximum entry.
"""
if label.ndim != 2:
-raise ValueError('Expect 2-D input label')
+raise ValueError('Expect 2-D input label. Got {}'.format(label.shape))
if np.max(label) >= _DATASET_MAX_ENTRIES[dataset]:
-raise ValueError('label value too large.')
+raise ValueError(
+    'label value too large: {} >= {}.'.format(
+        np.max(label), _DATASET_MAX_ENTRIES[dataset]))
colormap = create_label_colormap(dataset)
return colormap[label]
+def get_dataset_colormap_max_entries(dataset):
+  return _DATASET_MAX_ENTRIES[dataset]
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for providing semantic segmentation data."""
import tensorflow as tf
from deeplab import common
from deeplab import input_preprocess
slim = tf.contrib.slim
dataset_data_provider = slim.dataset_data_provider
def _get_data(data_provider, dataset_split):
"""Gets data from data provider.
Args:
data_provider: An object of slim.data_provider.
dataset_split: Dataset split.
Returns:
image: Image Tensor.
label: Label Tensor storing segmentation annotations.
image_name: Image name.
height: Image height.
width: Image width.
Raises:
ValueError: Failed to find label.
"""
if common.LABELS_CLASS not in data_provider.list_items():
raise ValueError('Failed to find labels.')
image, height, width = data_provider.get(
[common.IMAGE, common.HEIGHT, common.WIDTH])
# Some datasets do not contain image_name.
if common.IMAGE_NAME in data_provider.list_items():
image_name, = data_provider.get([common.IMAGE_NAME])
else:
image_name = tf.constant('')
label = None
if dataset_split != common.TEST_SET:
label, = data_provider.get([common.LABELS_CLASS])
return image, label, image_name, height, width
def get(dataset,
crop_size,
batch_size,
min_resize_value=None,
max_resize_value=None,
resize_factor=None,
min_scale_factor=1.,
max_scale_factor=1.,
scale_factor_step_size=0,
num_readers=1,
num_threads=1,
dataset_split=None,
is_training=True,
model_variant=None):
"""Gets the dataset split for semantic segmentation.
This function gets the dataset split for semantic segmentation. In
particular, it is a wrapper of (1) dataset_data_provider, which returns the
raw dataset split; (2) input_preprocess, which preprocesses the raw data; and
(3) the TensorFlow operation of batching the preprocessed data. The output
can then be directly used for training, evaluation, or visualization.
Args:
dataset: An instance of slim Dataset.
crop_size: Image crop size [height, width].
batch_size: Batch size.
min_resize_value: Desired size of the smaller image side.
max_resize_value: Maximum allowed size of the larger image side.
resize_factor: Resized dimensions are multiples of factor plus one.
min_scale_factor: Minimum scale factor value.
max_scale_factor: Maximum scale factor value.
scale_factor_step_size: The step size from min scale factor to max scale
factor. The input is randomly scaled based on the value of
(min_scale_factor, max_scale_factor, scale_factor_step_size).
num_readers: Number of readers for data provider.
num_threads: Number of threads for batching data.
dataset_split: Dataset split.
is_training: Is training or not.
model_variant: Model variant (string) for choosing how to mean-subtract the
images. See feature_extractor.network_map for supported model variants.
Returns:
A dictionary of batched Tensors for semantic segmentation.
Raises:
ValueError: dataset_split is None, failed to find labels, or label shape
is not valid.
"""
if dataset_split is None:
raise ValueError('Unknown dataset split.')
if model_variant is None:
tf.logging.warning('Please specify a model_variant. See '
'feature_extractor.network_map for supported model '
'variants.')
data_provider = dataset_data_provider.DatasetDataProvider(
dataset,
num_readers=num_readers,
num_epochs=None if is_training else 1,
shuffle=is_training)
image, label, image_name, height, width = _get_data(data_provider,
dataset_split)
if label is not None:
if label.shape.ndims == 2:
label = tf.expand_dims(label, 2)
elif label.shape.ndims == 3 and label.shape.dims[2] == 1:
pass
else:
raise ValueError('Input label shape must be [height, width], or '
'[height, width, 1].')
label.set_shape([None, None, 1])
original_image, image, label = input_preprocess.preprocess_image_and_label(
image,
label,
crop_height=crop_size[0],
crop_width=crop_size[1],
min_resize_value=min_resize_value,
max_resize_value=max_resize_value,
resize_factor=resize_factor,
min_scale_factor=min_scale_factor,
max_scale_factor=max_scale_factor,
scale_factor_step_size=scale_factor_step_size,
ignore_label=dataset.ignore_label,
is_training=is_training,
model_variant=model_variant)
sample = {
common.IMAGE: image,
common.IMAGE_NAME: image_name,
common.HEIGHT: height,
common.WIDTH: width
}
if label is not None:
sample[common.LABEL] = label
if not is_training:
# Original image is only used during visualization.
sample[common.ORIGINAL_IMAGE] = original_image,
num_threads = 1
return tf.train.batch(
sample,
batch_size=batch_size,
num_threads=num_threads,
capacity=32 * batch_size,
allow_smaller_final_batch=not is_training,
dynamic_pad=True)
......@@ -29,16 +29,20 @@ def save_annotation(label,
save_dir,
filename,
add_colormap=True,
+normalize_to_unit_values=False,
+scale_values=False,
colormap_type=get_dataset_colormap.get_pascal_name()):
"""Saves the given label to image on disk.
Args:
label: The numpy array to be saved. The data will be converted
to uint8 and saved as png image.
-save_dir: The directory to which the results will be saved.
-filename: The image filename.
-add_colormap: Add color map to the label or not.
-colormap_type: Colormap type for visualization.
+save_dir: String, the directory to which the results will be saved.
+filename: String, the image filename.
+add_colormap: Boolean, add color map to the label or not.
+normalize_to_unit_values: Boolean, normalize the input values to [0, 1].
+scale_values: Boolean, scale the input values to [0, 255] for visualization.
+colormap_type: String, colormap type for visualization.
"""
# Add colormap for visualizing the prediction.
if add_colormap:
......@@ -46,6 +50,15 @@ def save_annotation(label,
label, colormap_type)
else:
colored_label = label
+if normalize_to_unit_values:
+  min_value = np.amin(colored_label)
+  max_value = np.amax(colored_label)
+  range_value = max_value - min_value
+  if range_value != 0:
+    colored_label = (colored_label - min_value) / range_value
+if scale_values:
+  colored_label = 255. * colored_label
pil_image = img.fromarray(colored_label.astype(dtype=np.uint8))
with tf.gfile.Open('%s/%s.png' % (save_dir, filename), mode='w') as f:
......
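A numpy-only sketch of what the two new options do when both are enabled (mirroring the branches above, not the library code itself):
```python
import numpy as np

def normalize_then_scale(label):
  # normalize_to_unit_values: map to [0, 1]; scale_values: map to [0, 255].
  label = label.astype(np.float32)
  value_range = label.max() - label.min()
  if value_range != 0:
    label = (label - label.min()) / value_range
  return (255. * label).astype(np.uint8)

print(normalize_then_scale(np.array([[2., 4.], [6., 10.]])))
# [[  0  63]
#  [127 255]]
```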
......@@ -19,7 +19,11 @@ import six
import tensorflow as tf
from deeplab.core import preprocess_utils
slim = tf.contrib.slim
+def _div_maybe_zero(total_loss, num_present):
+  """Normalizes the total loss with the number of present pixels."""
+  return tf.to_float(num_present > 0) * tf.div(total_loss,
+                                               tf.maximum(1e-5, num_present))
def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
......@@ -28,6 +32,8 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
ignore_label,
loss_weight=1.0,
upsample_logits=True,
+hard_example_mining_step=0,
+top_k_percent_pixels=1.0,
scope=None):
"""Adds softmax cross entropy loss for logits of each scale.
......@@ -39,6 +45,15 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
ignore_label: Integer, label to ignore.
loss_weight: Float, loss weight.
upsample_logits: Boolean, upsample logits or not.
+hard_example_mining_step: An integer, the training step in which hard
+  example mining kicks in. Note that we gradually reduce the mining
+  percent to top_k_percent_pixels. For example, if
+  hard_example_mining_step = 100K and top_k_percent_pixels = 0.25, the
+  mining percent gradually reduces from 100% to 25% until 100K steps,
+  after which we only mine the top 25% of pixels.
+top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its value
+  < 1.0, only compute the loss for the top k percent pixels (e.g., the top
+  20% pixels). This is useful for hard pixel mining.
scope: String, the scope for the loss.
Raises:
......@@ -69,13 +84,48 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
scaled_labels = tf.reshape(scaled_labels, shape=[-1])
not_ignore_mask = tf.to_float(tf.not_equal(scaled_labels,
ignore_label)) * loss_weight
-one_hot_labels = slim.one_hot_encoding(
+one_hot_labels = tf.one_hot(
    scaled_labels, num_classes, on_value=1.0, off_value=0.0)
-tf.losses.softmax_cross_entropy(
-    one_hot_labels,
-    tf.reshape(logits, shape=[-1, num_classes]),
-    weights=not_ignore_mask,
-    scope=loss_scope)
+if top_k_percent_pixels == 1.0:
+  # Compute the loss for all pixels.
+  tf.losses.softmax_cross_entropy(
+      one_hot_labels,
+      tf.reshape(logits, shape=[-1, num_classes]),
+      weights=not_ignore_mask,
+      scope=loss_scope)
+else:
+  logits = tf.reshape(logits, shape=[-1, num_classes])
+  weights = not_ignore_mask
+  with tf.name_scope(loss_scope, 'softmax_hard_example_mining',
+                     [logits, one_hot_labels, weights]):
+    one_hot_labels = tf.stop_gradient(
+        one_hot_labels, name='labels_stop_gradient')
+    pixel_losses = tf.nn.softmax_cross_entropy_with_logits_v2(
+        labels=one_hot_labels,
+        logits=logits,
+        name='pixel_losses')
+    weighted_pixel_losses = tf.multiply(pixel_losses, weights)
+    num_pixels = tf.to_float(tf.shape(logits)[0])
+    # Compute the top_k_percent pixels based on current training step.
+    if hard_example_mining_step == 0:
+      # Directly focus on the top_k pixels.
+      top_k_pixels = tf.to_int32(top_k_percent_pixels * num_pixels)
+    else:
+      # Gradually reduce the mining percent to top_k_percent_pixels.
+      global_step = tf.to_float(tf.train.get_or_create_global_step())
+      ratio = tf.minimum(1.0, global_step / hard_example_mining_step)
+      top_k_pixels = tf.to_int32(
+          (ratio * top_k_percent_pixels + (1.0 - ratio)) * num_pixels)
+    top_k_losses, _ = tf.nn.top_k(weighted_pixel_losses,
+                                  k=top_k_pixels,
+                                  sorted=True,
+                                  name='top_k_percent_pixels')
+    total_loss = tf.reduce_sum(top_k_losses)
+    num_present = tf.reduce_sum(
+        tf.to_float(tf.not_equal(top_k_losses, 0.0)))
+    loss = _div_maybe_zero(total_loss, num_present)
+    tf.losses.add_loss(loss)
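The mining schedule is easy to sanity-check numerically; a plain-Python sketch of the effective mining fraction `ratio * top_k_percent_pixels + (1 - ratio)` (parameter values are illustrative):
```python
def mining_percent(step, hard_example_mining_step=100000,
                   top_k_percent_pixels=0.25):
  ratio = min(1.0, step / float(hard_example_mining_step))
  return ratio * top_k_percent_pixels + (1.0 - ratio)

for step in (0, 50000, 100000, 200000):
  print(step, mining_percent(step))
# 0 -> 1.0 (all pixels), 50000 -> 0.625, 100000 and later -> 0.25
```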
def get_model_init_fn(train_logdir,
......@@ -110,13 +160,22 @@ def get_model_init_fn(train_logdir,
if not initialize_last_layer:
exclude_list.extend(last_layers)
-variables_to_restore = slim.get_variables_to_restore(exclude=exclude_list)
+variables_to_restore = tf.contrib.framework.get_variables_to_restore(
+    exclude=exclude_list)
if variables_to_restore:
-return slim.assign_from_checkpoint_fn(
+init_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
    tf_initial_checkpoint,
    variables_to_restore,
    ignore_missing_vars=ignore_missing_vars)
+global_step = tf.train.get_or_create_global_step()
+
+def restore_fn(unused_scaffold, sess):
+  sess.run(init_op, init_feed_dict)
+  sess.run([global_step])
+
+return restore_fn
return None
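The returned `restore_fn(scaffold, sess)` matches the `init_fn` signature expected by `tf.train.Scaffold`, so it plugs directly into a MonitoredTrainingSession. A usage sketch (paths and layer names are illustrative, not from the source):
```python
import tensorflow as tf
from deeplab.utils import train_utils

init_fn = train_utils.get_model_init_fn(
    '/tmp/train_logdir',      # hypothetical train logdir
    '/tmp/init/model.ckpt',   # hypothetical initial checkpoint
    initialize_last_layer=False,
    last_layers=['logits'],
    ignore_missing_vars=True)
# Scaffold calls init_fn only when no checkpoint exists in the logdir,
# which is exactly when the initial (e.g., ImageNet) weights should load.
scaffold = tf.train.Scaffold(init_fn=init_fn)
```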
......@@ -138,7 +197,7 @@ def get_model_gradient_multipliers(last_layers, last_layer_gradient_multiplier):
"""
gradient_multipliers = {}
-for var in slim.get_model_variables():
+for var in tf.model_variables():
# Double the learning rate for biases.
if 'biases' in var.op.name:
gradient_multipliers[var.op.name] = 2.
......@@ -155,10 +214,15 @@ def get_model_gradient_multipliers(last_layers, last_layer_gradient_multiplier):
return gradient_multipliers
-def get_model_learning_rate(
-    learning_policy, base_learning_rate, learning_rate_decay_step,
-    learning_rate_decay_factor, training_number_of_steps, learning_power,
-    slow_start_step, slow_start_learning_rate):
+def get_model_learning_rate(learning_policy,
+                            base_learning_rate,
+                            learning_rate_decay_step,
+                            learning_rate_decay_factor,
+                            training_number_of_steps,
+                            learning_power,
+                            slow_start_step,
+                            slow_start_learning_rate,
+                            slow_start_burnin_type='none'):
"""Gets model's learning rate.
Computes the model's learning rate for different learning policy.
......@@ -181,31 +245,51 @@ def get_model_learning_rate(
slow_start_step: Training model with small learning rate for the first
few steps.
slow_start_learning_rate: The learning rate employed during slow start.
+slow_start_burnin_type: The burnin type for the slow start stage. Can be
+  `none`, which means no burnin, or `linear`, which means the learning rate
+  increases linearly from slow_start_learning_rate and reaches
+  base_learning_rate after slow_start_step steps.
Returns:
Learning rate for the specified learning policy.
Raises:
-ValueError: If learning policy is not recognized.
+ValueError: If learning policy or slow start burnin type is not recognized.
"""
global_step = tf.train.get_or_create_global_step()
+adjusted_global_step = global_step
+if slow_start_burnin_type != 'none':
+  adjusted_global_step -= slow_start_step
if learning_policy == 'step':
learning_rate = tf.train.exponential_decay(
base_learning_rate,
-global_step,
+adjusted_global_step,
learning_rate_decay_step,
learning_rate_decay_factor,
staircase=True)
elif learning_policy == 'poly':
learning_rate = tf.train.polynomial_decay(
base_learning_rate,
-global_step,
+adjusted_global_step,
training_number_of_steps,
end_learning_rate=0,
power=learning_power)
else:
raise ValueError('Unknown learning policy.')
+adjusted_slow_start_learning_rate = slow_start_learning_rate
+if slow_start_burnin_type == 'linear':
+  # Do linear burnin. Increase linearly from slow_start_learning_rate and
+  # reach base_learning_rate after (global_step >= slow_start_step).
+  adjusted_slow_start_learning_rate = (
+      slow_start_learning_rate +
+      (base_learning_rate - slow_start_learning_rate) *
+      tf.to_float(global_step) / slow_start_step)
+elif slow_start_burnin_type != 'none':
+  raise ValueError('Unknown burnin type.')
# Employ small learning rate at the first few steps for warm start.
-return tf.where(global_step < slow_start_step, slow_start_learning_rate,
-                learning_rate)
+return tf.where(global_step < slow_start_step,
+                adjusted_slow_start_learning_rate, learning_rate)
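A plain-Python sketch of the linear burnin (rates and step counts are illustrative):
```python
def burnin_lr(step, base_lr=0.007, slow_start_lr=0.0007, slow_start_step=300):
  # Ramp linearly from slow_start_lr to base_lr over slow_start_step steps;
  # afterwards the step/poly policy takes over.
  if step < slow_start_step:
    return slow_start_lr + (base_lr - slow_start_lr) * step / float(slow_start_step)
  return base_lr

for step in (0, 150, 300):
  print(step, round(burnin_lr(step), 5))
# 0 -> 0.0007, 150 -> 0.00385, 300 -> 0.007
```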
......@@ -17,19 +17,15 @@
See model.py for more details and usage.
"""
-import math
import os.path
import time
import numpy as np
import tensorflow as tf
from deeplab import common
from deeplab import model
-from deeplab.datasets import segmentation_dataset
-from deeplab.utils import input_generator
+from deeplab.datasets import data_generator
from deeplab.utils import save_annotation
-slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
......@@ -186,11 +182,24 @@ def _process_batch(sess, original_images, semantic_predictions, image_names,
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
# Get dataset-dependent information.
-dataset = segmentation_dataset.get_dataset(
-    FLAGS.dataset, FLAGS.vis_split, dataset_dir=FLAGS.dataset_dir)
+dataset = data_generator.Dataset(
+    dataset_name=FLAGS.dataset,
+    split_name=FLAGS.vis_split,
+    dataset_dir=FLAGS.dataset_dir,
+    batch_size=FLAGS.vis_batch_size,
+    crop_size=FLAGS.vis_crop_size,
+    min_resize_value=FLAGS.min_resize_value,
+    max_resize_value=FLAGS.max_resize_value,
+    resize_factor=FLAGS.resize_factor,
+    model_variant=FLAGS.model_variant,
+    is_training=False,
+    should_shuffle=False,
+    should_repeat=False)
train_id_to_eval_id = None
-if dataset.name == segmentation_dataset.get_cityscapes_dataset_name():
+if dataset.dataset_name == data_generator.get_cityscapes_dataset_name():
tf.logging.info('Cityscapes requires converting train_id to eval_id.')
train_id_to_eval_id = _CITYSCAPES_TRAIN_ID_TO_EVAL_ID
......@@ -204,20 +213,11 @@ def main(unused_argv):
tf.logging.info('Visualizing on %s set', FLAGS.vis_split)
-g = tf.Graph()
-with g.as_default():
-  samples = input_generator.get(dataset,
-                                FLAGS.vis_crop_size,
-                                FLAGS.vis_batch_size,
-                                min_resize_value=FLAGS.min_resize_value,
-                                max_resize_value=FLAGS.max_resize_value,
-                                resize_factor=FLAGS.resize_factor,
-                                dataset_split=FLAGS.vis_split,
-                                is_training=False,
-                                model_variant=FLAGS.model_variant)
+with tf.Graph().as_default():
+  samples = dataset.get_one_shot_iterator().get_next()
model_options = common.ModelOptions(
-outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_classes},
+outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_of_classes},
crop_size=FLAGS.vis_crop_size,
atrous_rates=FLAGS.atrous_rates,
output_stride=FLAGS.output_stride)
......@@ -244,7 +244,7 @@ def main(unused_argv):
# Reverse the resizing and padding operations performed in preprocessing.
# First, we slice the valid regions (i.e., remove padded region) and then
-# we reisze the predictions back.
+# we resize the predictions back.
original_image = tf.squeeze(samples[common.ORIGINAL_IMAGE])
original_image_shape = tf.shape(original_image)
predictions = tf.slice(
......@@ -259,40 +259,35 @@ def main(unused_argv):
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
align_corners=True), 3)
-tf.train.get_or_create_global_step()
-saver = tf.train.Saver(slim.get_variables_to_restore())
-sv = tf.train.Supervisor(graph=g,
-                         logdir=FLAGS.vis_logdir,
-                         init_op=tf.global_variables_initializer(),
-                         summary_op=None,
-                         summary_writer=None,
-                         global_step=None,
-                         saver=saver)
-num_batches = int(math.ceil(
-    dataset.num_samples / float(FLAGS.vis_batch_size)))
-last_checkpoint = None
# Loop to visualize the results when new checkpoint is created.
-num_iters = 0
-while (FLAGS.max_number_of_iterations <= 0 or
-       num_iters < FLAGS.max_number_of_iterations):
-  num_iters += 1
-  last_checkpoint = slim.evaluation.wait_for_new_checkpoint(
-      FLAGS.checkpoint_dir, last_checkpoint)
-  start = time.time()
+num_iteration = 0
+max_num_iteration = FLAGS.max_number_of_iterations
+
+checkpoints_iterator = tf.contrib.training.checkpoints_iterator(
+    FLAGS.checkpoint_dir, min_interval_secs=FLAGS.eval_interval_secs)
+for checkpoint_path in checkpoints_iterator:
+  if max_num_iteration > 0 and num_iteration > max_num_iteration:
+    break
+  num_iteration += 1
tf.logging.info(
'Starting visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
time.gmtime()))
-tf.logging.info('Visualizing with model %s', last_checkpoint)
-with sv.managed_session(FLAGS.master,
-                        start_standard_services=False) as sess:
-  sv.start_queue_runners(sess)
-  sv.saver.restore(sess, last_checkpoint)
+tf.logging.info('Visualizing with model %s', checkpoint_path)
+tf.train.get_or_create_global_step()
+scaffold = tf.train.Scaffold(init_op=tf.global_variables_initializer())
+session_creator = tf.train.ChiefSessionCreator(
+    scaffold=scaffold,
+    master=FLAGS.master,
+    checkpoint_filename_with_path=checkpoint_path)
+with tf.train.MonitoredSession(
+    session_creator=session_creator, hooks=None) as sess:
+  batch = 0
  image_id_offset = 0
-  for batch in range(num_batches):
-    tf.logging.info('Visualizing batch %d / %d', batch + 1, num_batches)
+  while not sess.should_stop():
+    tf.logging.info('Visualizing batch %d', batch + 1)
_process_batch(sess=sess,
original_images=samples[common.ORIGINAL_IMAGE],
semantic_predictions=predictions,
......@@ -304,14 +299,11 @@ def main(unused_argv):
raw_save_dir=raw_save_dir,
train_id_to_eval_id=train_id_to_eval_id)
image_id_offset += FLAGS.vis_batch_size
+batch += 1
tf.logging.info(
'Finished visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
time.gmtime()))
-time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
-if time_to_next_eval > 0:
-  time.sleep(time_to_next_eval)
if __name__ == '__main__':
flags.mark_flag_as_required('checkpoint_dir')
......
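The pattern above generalizes beyond vis.py: `checkpoints_iterator` blocks until a new checkpoint appears, and `MonitoredSession` restores it and turns the one-shot iterator's `OutOfRangeError` into `should_stop()`. A stripped-down sketch (the fetch name and directory are hypothetical):
```python
import tensorflow as tf

# Assumes a graph with a tensor named 'predictions' fed by a one-shot
# iterator, and a training job writing checkpoints to checkpoint_dir.
checkpoint_dir = '/tmp/train_logdir'
for checkpoint_path in tf.contrib.training.checkpoints_iterator(
    checkpoint_dir, min_interval_secs=60):
  session_creator = tf.train.ChiefSessionCreator(
      checkpoint_filename_with_path=checkpoint_path)
  with tf.train.MonitoredSession(session_creator=session_creator) as sess:
    while not sess.should_stop():
      sess.run('predictions:0')  # stops once the iterator is exhausted
```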