Unverified commit 2041d5ca authored by Yukun Zhu, committed by GitHub

Merge pull request #3852 from DefineFC/deeplab-py3

Deeplab py3
parents 047bcef3 279aa927
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Tests for xception.py."""
+import six
 import numpy as np
 import tensorflow as tf
@@ -309,7 +310,7 @@ class XceptionNetworkTest(tf.test.TestCase):
           'xception/middle_flow/block1': [2, 14, 14, 4],
           'xception/exit_flow/block1': [2, 7, 7, 8],
           'xception/exit_flow/block2': [2, 7, 7, 16]}
-      for endpoint, shape in endpoint_to_shape.iteritems():
+      for endpoint, shape in six.iteritems(endpoint_to_shape):
         self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)

   def testFullyConvolutionalEndpointShapes(self):
@@ -330,7 +331,7 @@ class XceptionNetworkTest(tf.test.TestCase):
           'xception/middle_flow/block1': [2, 21, 21, 4],
           'xception/exit_flow/block1': [2, 11, 11, 8],
           'xception/exit_flow/block2': [2, 11, 11, 16]}
-      for endpoint, shape in endpoint_to_shape.iteritems():
+      for endpoint, shape in six.iteritems(endpoint_to_shape):
         self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)

   def testAtrousFullyConvolutionalEndpointShapes(self):
@@ -352,7 +353,7 @@ class XceptionNetworkTest(tf.test.TestCase):
           'xception/middle_flow/block1': [2, 41, 41, 4],
           'xception/exit_flow/block1': [2, 41, 41, 8],
           'xception/exit_flow/block2': [2, 41, 41, 16]}
-      for endpoint, shape in endpoint_to_shape.iteritems():
+      for endpoint, shape in six.iteritems(endpoint_to_shape):
         self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)

   def testAtrousFullyConvolutionalValues(self):
...
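
Note: the recurring iteritems change above is the core Python 3 fix in this file. Python 3 removed dict.iteritems(), keeping only items(), while six.iteritems(d) dispatches to whichever the running interpreter provides. A minimal standalone sketch (the dict contents are illustrative, not taken from the commit):

    import six

    endpoint_to_shape = {'block1': [2, 14, 14, 4],  # hypothetical entries,
                         'block2': [2, 7, 7, 8]}    # for illustration only

    # six.iteritems(d) resolves to d.iteritems() on Python 2 and d.items()
    # on Python 3, so the loop body never has to care which one exists.
    for endpoint, shape in six.iteritems(endpoint_to_shape):
        print(endpoint, shape)
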
@@ -154,10 +154,10 @@ def _convert_dataset(dataset_split):
             i + 1, num_images, shard_id))
         sys.stdout.flush()
         # Read the image.
-        image_data = tf.gfile.FastGFile(image_files[i], 'r').read()
+        image_data = tf.gfile.FastGFile(image_files[i], 'rb').read()
         height, width = image_reader.read_image_dims(image_data)
         # Read the semantic segmentation annotation.
-        seg_data = tf.gfile.FastGFile(label_files[i], 'r').read()
+        seg_data = tf.gfile.FastGFile(label_files[i], 'rb').read()
         seg_height, seg_width = label_reader.read_image_dims(seg_data)
         if height != seg_height or width != seg_width:
           raise RuntimeError('Shape mismatched between image and label.')
...
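
Note: the 'r' to 'rb' mode changes fix a genuine Python 3 break. Under Python 3, mode 'r' opens a file as text and tries to decode it (typically as UTF-8), which fails on raw PNG/JPEG bytes, whereas Python 2 returned raw bytes either way. A small sketch with a hypothetical path, assuming the TF 1.x tf.gfile API used by this commit:

    import tensorflow as tf

    image_path = '/tmp/example.png'  # hypothetical path for illustration

    # 'rb' yields raw bytes on both Python 2 and Python 3; under Python 3,
    # 'r' would attempt a text decode and raise on binary image data.
    image_data = tf.gfile.FastGFile(image_path, 'rb').read()
    print(type(image_data))  # <class 'bytes'> on Python 3
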
@@ -125,7 +125,10 @@ def _bytes_list_feature(values):
   Returns:
     A TF-Feature.
   """
-  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
+  def norm2bytes(value):
+    return value.encode() if isinstance(value, str) else value
+
+  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[norm2bytes(values)]))


 def image_seg_to_tfexample(image_data, filename, height, width, seg_data):
...
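
Note: norm2bytes is needed because tf.train.BytesList only accepts bytes. Image payloads read with 'rb' already are bytes, but filenames arrive as Python 3 str and must be encoded first. A self-contained restatement of the changed helper, assuming TF 1.x (the sample filename is hypothetical):

    import tensorflow as tf

    def _bytes_list_feature(values):
      """Returns a TF-Feature of bytes, from either str or bytes input."""
      def norm2bytes(value):
        # Encode a Python 3 str to UTF-8 bytes; bytes pass through unchanged.
        return value.encode() if isinstance(value, str) else value
      return tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[norm2bytes(values)]))

    feature = _bytes_list_feature('2007_000027')  # hypothetical filename
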
@@ -114,13 +114,13 @@ def _convert_dataset(dataset_split):
       # Read the image.
       image_filename = os.path.join(
           FLAGS.image_folder, filenames[i] + '.' + FLAGS.image_format)
-      image_data = tf.gfile.FastGFile(image_filename, 'r').read()
+      image_data = tf.gfile.FastGFile(image_filename, 'rb').read()
       height, width = image_reader.read_image_dims(image_data)
       # Read the semantic segmentation annotation.
       seg_filename = os.path.join(
           FLAGS.semantic_segmentation_folder,
           filenames[i] + '.' + FLAGS.label_format)
-      seg_data = tf.gfile.FastGFile(seg_filename, 'r').read()
+      seg_data = tf.gfile.FastGFile(seg_filename, 'rb').read()
       seg_height, seg_width = label_reader.read_image_dims(seg_data)
       if height != seg_height or width != seg_width:
         raise RuntimeError('Shape mismatched between image and label.')
...
@@ -17,6 +17,7 @@
 See model.py for more details and usage.
 """
+import six
 import math
 import tensorflow as tf
 from deeplab import common
@@ -144,7 +145,7 @@ def main(unused_argv):
     metrics_to_values, metrics_to_updates = (
         tf.contrib.metrics.aggregate_metric_map(metric_map))

-    for metric_name, metric_value in metrics_to_values.iteritems():
+    for metric_name, metric_value in six.iteritems(metrics_to_values):
       slim.summaries.add_scalar_summary(
           metric_value, metric_name, print_summary=True)
@@ -163,7 +164,7 @@ def main(unused_argv):
         checkpoint_dir=FLAGS.checkpoint_dir,
         logdir=FLAGS.eval_logdir,
         num_evals=num_batches,
-        eval_op=metrics_to_updates.values(),
+        eval_op=list(metrics_to_updates.values()),
         max_number_of_evaluations=num_eval_iters,
         eval_interval_secs=FLAGS.eval_interval_secs)
...
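
Note: the list(...) wrapper around metrics_to_updates.values() is another Python 3 difference. dict.values() now returns a lazy view object rather than a list, and the evaluation loop expects a concrete, indexable sequence of update ops. In isolation (placeholder strings stand in for the real update ops):

    metrics_to_updates = {'miou': 'update_op_a', 'accuracy': 'update_op_b'}

    view = metrics_to_updates.values()
    # Python 3: dict_values([...]) -- a view; view[0] would raise TypeError.

    eval_op = list(metrics_to_updates.values())
    print(eval_op[0])  # a real list: indexing and len() work as before
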
@@ -17,6 +17,7 @@
 See model.py for more details and usage.
 """
+import six
 import tensorflow as tf
 from deeplab import common
 from deeplab import model
@@ -190,7 +191,7 @@ def _build_deeplab(inputs_queue, outputs_to_num_classes, ignore_label):
       is_training=True,
       fine_tune_batch_norm=FLAGS.fine_tune_batch_norm)

-  for output, num_classes in outputs_to_num_classes.iteritems():
+  for output, num_classes in six.iteritems(outputs_to_num_classes):
     train_utils.add_softmax_cross_entropy_loss_for_each_scale(
         outputs_to_scales_to_logits[output],
         samples[common.LABEL],
@@ -217,7 +218,7 @@ def main(unused_argv):
   assert FLAGS.train_batch_size % config.num_clones == 0, (
       'Training batch size not divisble by number of clones (GPUs).')
-  clone_batch_size = FLAGS.train_batch_size / config.num_clones
+  clone_batch_size = int(FLAGS.train_batch_size / config.num_clones)

   # Get dataset-dependent information.
   dataset = segmentation_dataset.get_dataset(
...
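
Note: the int(...) cast fixes Python 3's true division, where / between two ints returns a float, and a float clone_batch_size would later fail where an integer shape or count is required. Since the preceding assert guarantees divisibility, floor division would work equally well:

    train_batch_size, num_clones = 8, 2

    print(train_batch_size / num_clones)       # Python 3: 4.0 -- a float
    print(int(train_batch_size / num_clones))  # 4, the fix used in the diff
    print(train_batch_size // num_clones)      # 4, an equivalent alternative
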
@@ -14,6 +14,8 @@
 # ==============================================================================
 """Utility functions for training."""
+import six
+
 import tensorflow as tf

 slim = tf.contrib.slim
@@ -44,7 +46,7 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
   if labels is None:
     raise ValueError('No label for softmax cross entropy loss.')

-  for scale, logits in scales_to_logits.iteritems():
+  for scale, logits in six.iteritems(scales_to_logits):
     loss_scope = None
     if scope:
       loss_scope = '%s_%s' % (scope, scale)
...