Commit 1498d941 authored by Yukun Zhu, committed by aquariusjay

Update for py3 and some internal changes (#7786)

parent 42c3b8f0
+# Lint as: python2, python3
 # Copyright 2019 The TensorFlow Authors All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
...
+# Lint as: python2, python3
 # Copyright 2019 The TensorFlow Authors All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,6 +27,7 @@ from absl import flags
 import numpy as np
 import scipy.misc
 import six
+from six.moves import map
 FLAGS = flags.FLAGS
...
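Why the `six.moves.map` import above matters: it gives `map()` the Python 3 semantics (a lazy iterator) on both interpreter versions, so call sites that index or reuse the result wrap it in `list()`. A minimal illustration, not code from the repository:

```python
from six.moves import map  # iterator-returning map on both Python 2 and 3

scales = map(float, ['0.25', '0.5', '1.0'])
scales = list(scales)  # materialize once before indexing or reusing
print(scales[0])       # 0.25
```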
+# Lint as: python2, python3
 # Copyright 2019 The TensorFlow Authors All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
...
+# Lint as: python2, python3
 # Copyright 2018 The TensorFlow Authors All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,13 +17,15 @@
 import os
 import tensorflow as tf
+from google3.learning.brain.contrib import quantize as contrib_quantize
+from google3.learning.brain.contrib import slim as contrib_slim
 from tensorflow.python.tools import freeze_graph
 from deeplab import common
 from deeplab import input_preprocess
 from deeplab import model
-slim = tf.contrib.slim
+slim = contrib_slim
 flags = tf.app.flags
 FLAGS = flags.FLAGS
@@ -63,10 +66,14 @@ flags.DEFINE_bool('save_inference_graph', False,
 # Input name of the exported model.
 _INPUT_NAME = 'ImageTensor'
-# Output name of the exported model.
+# Output name of the exported predictions.
 _OUTPUT_NAME = 'SemanticPredictions'
 _RAW_OUTPUT_NAME = 'RawSemanticPredictions'
+# Output name of the exported probabilities.
+_OUTPUT_PROB_NAME = 'SemanticProbabilities'
+_RAW_OUTPUT_PROB_NAME = 'RawSemanticProbabilities'
 def _create_input_tensors():
   """Creates and prepares input tensors for DeepLab model.
@@ -139,11 +146,16 @@ def main(unused_argv):
     raw_predictions = tf.identity(
         tf.cast(predictions[common.OUTPUT_TYPE], tf.float32),
         _RAW_OUTPUT_NAME)
+    raw_probabilities = tf.identity(
+        predictions[common.OUTPUT_TYPE + model.PROB_SUFFIX],
+        _RAW_OUTPUT_PROB_NAME)
     # Crop the valid regions from the predictions.
-    semantic_predictions = tf.slice(
-        raw_predictions,
-        [0, 0, 0],
-        [1, resized_image_size[0], resized_image_size[1]])
+    semantic_predictions = raw_predictions[
+        :, :resized_image_size[0], :resized_image_size[1]]
+    semantic_probabilities = raw_probabilities[
+        :, :resized_image_size[0], :resized_image_size[1]]
     # Resize back the prediction to the original image size.
     def _resize_label(label, label_size):
       # Expand dimension of label to [1, height, width, 1] for resize operation.
@@ -157,8 +169,12 @@ def main(unused_argv):
     semantic_predictions = _resize_label(semantic_predictions, image_size)
     semantic_predictions = tf.identity(semantic_predictions, name=_OUTPUT_NAME)
+    semantic_probabilities = tf.image.resize_bilinear(
+        semantic_probabilities, image_size, align_corners=True,
+        name=_OUTPUT_PROB_NAME)
     if FLAGS.quantize_delay_step >= 0:
-      tf.contrib.quantize.create_eval_graph()
+      contrib_quantize.create_eval_graph()
     saver = tf.train.Saver(tf.all_variables())
@@ -169,7 +185,7 @@ def main(unused_argv):
         graph_def,
         saver.as_saver_def(),
         FLAGS.checkpoint_path,
-        _OUTPUT_NAME,
+        _OUTPUT_NAME + ',' + _OUTPUT_PROB_NAME,
         restore_op_name=None,
         filename_tensor_name=None,
         output_graph=FLAGS.export_path,
...
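The cropping change above swaps `tf.slice` for NumPy-style strided slicing; for a batch of one the two forms produce the same tensor. A minimal sketch, assuming TF 1.x and a toy 1x4x5 input in place of the real predictions:

```python
import numpy as np
import tensorflow as tf  # TF 1.x assumed

raw_predictions = tf.constant(np.arange(20).reshape(1, 4, 5), dtype=tf.float32)
resized_image_size = [3, 4]  # toy stand-in for the real resized size

# Old form: explicit begin/size slice keeping the batch dimension of size 1.
crop_old = tf.slice(raw_predictions, [0, 0, 0],
                    [1, resized_image_size[0], resized_image_size[1]])
# New form used in the diff: keep the whole (size-1) batch, crop height/width.
crop_new = raw_predictions[:, :resized_image_size[0], :resized_image_size[1]]

with tf.Session() as sess:
  old_val, new_val = sess.run([crop_old, crop_new])
print(np.array_equal(old_val, new_val))  # True for a batch of one
```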
@@ -98,12 +98,12 @@ For quantized (8bit) model, un-tar'ed directory includes:
 * a converted TFlite FlatBuffer file (frozen_inference_graph.tflite)
-Checkpoint name | Eval OS | Eval scales | Left-right Flip | Multiply-Adds | Quantize | PASCAL mIOU | File Size
---------------------------------------------------------------------------------------------------------------------------------------------- | :-----: | :---------: | :-------------: | :-----------: | :------: | :----------: | :-------:
-[mobilenetv2_dm05_coco_voc_trainaug](http://download.tensorflow.org/models/deeplabv3_mnv2_dm05_pascal_trainaug_2018_10_01.tar.gz) | 16 | [1.0] | No | 0.88B | No | 70.19% (val) | 7.6MB
-[mobilenetv2_dm05_coco_voc_trainaug_8bit](http://download.tensorflow.org/models/deeplabv3_mnv2_dm05_pascal_train_aug_8bit_2019_04_26.tar.gz) | 16 | [1.0] | No | 0.88B | Yes | 69.65% (val) | 8.2MB
-[mobilenetv2_coco_voc_trainaug](http://download.tensorflow.org/models/deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz) | 16 | [1.0] | No | 2.75B | No | 75.32% (val) | 23MB
-[mobilenetv2_coco_voc_trainaug_8bit](http://download.tensorflow.org/models/deeplabv3_mnv2_pascal_train_aug_8bit_2019_04_26.tar.gz) | 16 | [1.0] | No | 2.75B | Yes | 74.26% (val) | 24MB
+Checkpoint name | Eval OS | Eval scales | Left-right Flip | Multiply-Adds | Quantize | PASCAL mIOU | Folder Size | TFLite File Size
+-------------------------------------------------------------------------------------------------------------------------------------------- | :-----: | :---------: | :-------------: | :-----------: | :------: | :----------: | :-------: | :-------:
+[mobilenetv2_dm05_coco_voc_trainaug](http://download.tensorflow.org/models/deeplabv3_mnv2_dm05_pascal_trainaug_2018_10_01.tar.gz) | 16 | [1.0] | No | 0.88B | No | 70.19% (val) | 7.6MB | N/A
+[mobilenetv2_dm05_coco_voc_trainaug_8bit](http://download.tensorflow.org/models/deeplabv3_mnv2_dm05_pascal_train_aug_8bit_2019_04_26.tar.gz) | 16 | [1.0] | No | 0.88B | Yes | 69.65% (val) | 8.2MB | 751.1KB
+[mobilenetv2_coco_voc_trainaug](http://download.tensorflow.org/models/deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz) | 16 | [1.0] | No | 2.75B | No | 75.32% (val) | 23MB | N/A
+[mobilenetv2_coco_voc_trainaug_8bit](http://download.tensorflow.org/models/deeplabv3_mnv2_pascal_train_aug_8bit_2019_04_26.tar.gz) | 16 | [1.0] | No | 2.75B | Yes | 74.26% (val) | 24MB | 2.2MB
 Note that you might need the nightly build of TensorFlow (see
 [here](https://www.tensorflow.org/install) for install instructions) to convert
...
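A hedged sketch of running one of the frozen graphs listed above, assuming TF 1.x, Pillow, and the tensor names exported by export_model.py (`ImageTensor` and `SemanticPredictions`); the local file paths are placeholders:

```python
import numpy as np
import tensorflow as tf  # TF 1.x assumed
from PIL import Image

graph_def = tf.GraphDef()
# 'frozen_inference_graph.pb' is assumed to come from the un-tar'ed folder.
with tf.gfile.GFile('frozen_inference_graph.pb', 'rb') as f:
  graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
  tf.import_graph_def(graph_def, name='')

# The exported graph expects a uint8 batch of shape [1, height, width, 3].
image = np.array(Image.open('input.jpg'))[None]

with tf.Session(graph=graph) as sess:
  seg_map = sess.run('SemanticPredictions:0',
                     feed_dict={'ImageTensor:0': image})
print(seg_map.shape)  # [1, height, width] per-pixel class indices
```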
+# Lint as: python2, python3
 # Copyright 2018 The TensorFlow Authors All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
...
+# Lint as: python2, python3
 # Copyright 2018 The TensorFlow Authors All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
...
+# Lint as: python2, python3
 # Copyright 2018 The TensorFlow Authors All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,7 +27,11 @@ defined by the different datasets. Supported colormaps are:
 * PASCAL VOC 2012 (http://host.robots.ox.ac.uk/pascal/VOC/).
 """
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
 import numpy as np
+from six.moves import range
 # Dataset names.
 _ADE20K = 'ade20k'
@@ -39,7 +44,7 @@ _DATASET_MAX_ENTRIES = {
     _ADE20K: 151,
     _CITYSCAPES: 256,
     _MAPILLARY_VISTAS: 66,
-    _PASCAL: 256,
+    _PASCAL: 512,
 }
@@ -318,7 +323,7 @@ def create_pascal_label_colormap():
   colormap = np.zeros((_DATASET_MAX_ENTRIES[_PASCAL], 3), dtype=int)
   ind = np.arange(_DATASET_MAX_ENTRIES[_PASCAL], dtype=int)
-  for shift in reversed(range(8)):
+  for shift in reversed(list(range(8))):
     for channel in range(3):
       colormap[:, channel] |= bit_get(ind, channel) << shift
     ind >>= 3
...
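The two changes above enlarge the PASCAL colormap to 512 entries and make the loop explicit about list materialization. A standalone sketch of the bit-interleaving trick (reimplemented locally so it runs on its own), showing that class 1 still maps to the familiar (128, 0, 0):

```python
import numpy as np

def bit_get(val, idx):
  """Returns the bit of val at position idx."""
  return (val >> idx) & 1

num_entries = 512  # the new _PASCAL value
colormap = np.zeros((num_entries, 3), dtype=int)
ind = np.arange(num_entries, dtype=int)
# Feed the three lowest label bits into R/G/B, from the most significant
# output bit down, then shift the label right by three bits and repeat.
for shift in reversed(list(range(8))):
  for channel in range(3):
    colormap[:, channel] |= bit_get(ind, channel) << shift
  ind >>= 3

print(colormap[1])  # [128   0   0]
```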
+# Lint as: python2, python3
 # Copyright 2018 The TensorFlow Authors All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -54,7 +55,7 @@ class VisualizationUtilTest(tf.test.TestCase):
   def testUnExpectedLabelValueForLabelToPASCALColorImage(self):
     """Raise ValueError when input value exceeds range."""
-    label = np.array([[120], [300]])
+    label = np.array([[120], [600]])
     with self.assertRaises(ValueError):
       get_dataset_colormap.label_to_color_image(
           label, get_dataset_colormap.get_pascal_name())
...
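The test value moves from 300 to 600 because, with 512 colormap entries, 300 is now a valid label while values beyond the table should still raise. A hedged usage sketch (the import path is assumed from the repository layout):

```python
import numpy as np
from deeplab.utils import get_dataset_colormap  # import path assumed

valid = np.array([[120], [300]])   # in range now that _PASCAL allows 512 entries
colored = get_dataset_colormap.label_to_color_image(
    valid, get_dataset_colormap.get_pascal_name())

too_large = np.array([[120], [600]])  # exceeds the 512-entry colormap
try:
  get_dataset_colormap.label_to_color_image(
      too_large, get_dataset_colormap.get_pascal_name())
except ValueError as e:
  print('raised as expected:', e)
```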
+# Lint as: python2, python3
 # Copyright 2018 The TensorFlow Authors All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
...
+# Lint as: python2, python3
 # Copyright 2018 The TensorFlow Authors All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,10 +18,16 @@
 See model.py for more details and usage.
 """
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
 import os.path
 import time
 import numpy as np
+from six.moves import range
 import tensorflow as tf
+from tensorflow.contrib import quantize as contrib_quantize
+from tensorflow.contrib import training as contrib_training
 from deeplab import common
 from deeplab import model
 from deeplab.datasets import data_generator
@@ -80,7 +87,7 @@ flags.DEFINE_string('vis_split', 'val',
 flags.DEFINE_string('dataset_dir', None, 'Where the dataset reside.')
-flags.DEFINE_enum('colormap_type', 'pascal', ['pascal', 'cityscapes'],
+flags.DEFINE_enum('colormap_type', 'pascal', ['pascal', 'cityscapes', 'ade20k'],
                   'Visualization colormap type.')
 flags.DEFINE_boolean('also_save_raw_predictions', False,
@@ -268,12 +275,12 @@ def main(unused_argv):
     tf.train.get_or_create_global_step()
     if FLAGS.quantize_delay_step >= 0:
-      tf.contrib.quantize.create_eval_graph()
+      contrib_quantize.create_eval_graph()
     num_iteration = 0
     max_num_iteration = FLAGS.max_number_of_iterations
-    checkpoints_iterator = tf.contrib.training.checkpoints_iterator(
+    checkpoints_iterator = contrib_training.checkpoints_iterator(
        FLAGS.checkpoint_dir, min_interval_secs=FLAGS.eval_interval_secs)
     for checkpoint_path in checkpoints_iterator:
       num_iteration += 1
...
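A minimal sketch of the polling loop vis.py builds around `contrib_training.checkpoints_iterator`, assuming TF 1.x with `tf.contrib` available; the directory and interval are placeholders for `FLAGS.checkpoint_dir` and `FLAGS.eval_interval_secs`:

```python
from tensorflow.contrib import training as contrib_training

num_iteration = 0
max_num_iteration = 5  # stand-in for FLAGS.max_number_of_iterations

# Yields each new checkpoint path as it appears, waiting between polls.
checkpoints_iterator = contrib_training.checkpoints_iterator(
    '/tmp/deeplab/train_logs', min_interval_secs=60)  # placeholder values
for checkpoint_path in checkpoints_iterator:
  num_iteration += 1
  print('Visualizing checkpoint', checkpoint_path)
  # ... run the visualization graph for this checkpoint ...
  if max_num_iteration > 0 and num_iteration >= max_num_iteration:
    break
```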