"src/vscode:/vscode.git/clone" did not exist on "87252d80c3ea8eb6fba8b6de8c2dac9ede4fadee"
Commit 26c96542 authored by DefineFC's avatar DefineFC
Browse files

prepare data, train, and eval on py3

parent 65f4d60b
......@@ -154,10 +154,10 @@ def _convert_dataset(dataset_split):
i + 1, num_images, shard_id))
sys.stdout.flush()
# Read the image.
image_data = tf.gfile.FastGFile(image_files[i], 'r').read()
image_data = tf.gfile.FastGFile(image_files[i], 'rb').read()
height, width = image_reader.read_image_dims(image_data)
# Read the semantic segmentation annotation.
seg_data = tf.gfile.FastGFile(label_files[i], 'r').read()
seg_data = tf.gfile.FastGFile(label_files[i], 'rb').read()
seg_height, seg_width = label_reader.read_image_dims(seg_data)
if height != seg_height or width != seg_width:
raise RuntimeError('Shape mismatched between image and label.')
......
......@@ -125,7 +125,10 @@ def _bytes_list_feature(values):
Returns:
A TF-Feature.
"""
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
def norm2bytes(value):
return value.encode() if isinstance(value, str) else value
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[norm2bytes(values)]))
def image_seg_to_tfexample(image_data, filename, height, width, seg_data):
......
......@@ -114,13 +114,13 @@ def _convert_dataset(dataset_split):
# Read the image.
image_filename = os.path.join(
FLAGS.image_folder, filenames[i] + '.' + FLAGS.image_format)
image_data = tf.gfile.FastGFile(image_filename, 'r').read()
image_data = tf.gfile.FastGFile(image_filename, 'rb').read()
height, width = image_reader.read_image_dims(image_data)
# Read the semantic segmentation annotation.
seg_filename = os.path.join(
FLAGS.semantic_segmentation_folder,
filenames[i] + '.' + FLAGS.label_format)
seg_data = tf.gfile.FastGFile(seg_filename, 'r').read()
seg_data = tf.gfile.FastGFile(seg_filename, 'rb').read()
seg_height, seg_width = label_reader.read_image_dims(seg_data)
if height != seg_height or width != seg_width:
raise RuntimeError('Shape mismatched between image and label.')
......
......@@ -17,6 +17,7 @@
See model.py for more details and usage.
"""
import six
import math
import tensorflow as tf
from deeplab import common
......@@ -144,7 +145,7 @@ def main(unused_argv):
metrics_to_values, metrics_to_updates = (
tf.contrib.metrics.aggregate_metric_map(metric_map))
for metric_name, metric_value in metrics_to_values.iteritems():
for metric_name, metric_value in six.iteritems(metrics_to_values):
slim.summaries.add_scalar_summary(
metric_value, metric_name, print_summary=True)
......@@ -163,7 +164,7 @@ def main(unused_argv):
checkpoint_dir=FLAGS.checkpoint_dir,
logdir=FLAGS.eval_logdir,
num_evals=num_batches,
eval_op=metrics_to_updates.values(),
eval_op=list(metrics_to_updates.values()),
max_number_of_evaluations=num_eval_iters,
eval_interval_secs=FLAGS.eval_interval_secs)
......
......@@ -17,6 +17,7 @@
See model.py for more details and usage.
"""
import six
import tensorflow as tf
from deeplab import common
from deeplab import model
......@@ -190,7 +191,7 @@ def _build_deeplab(inputs_queue, outputs_to_num_classes, ignore_label):
is_training=True,
fine_tune_batch_norm=FLAGS.fine_tune_batch_norm)
for output, num_classes in outputs_to_num_classes.iteritems():
for output, num_classes in six.iteritems(outputs_to_num_classes):
train_utils.add_softmax_cross_entropy_loss_for_each_scale(
outputs_to_scales_to_logits[output],
samples[common.LABEL],
......@@ -217,7 +218,7 @@ def main(unused_argv):
assert FLAGS.train_batch_size % config.num_clones == 0, (
'Training batch size not divisble by number of clones (GPUs).')
clone_batch_size = FLAGS.train_batch_size / config.num_clones
clone_batch_size = int(FLAGS.train_batch_size / config.num_clones)
# Get dataset-dependent information.
dataset = segmentation_dataset.get_dataset(
......
......@@ -14,6 +14,8 @@
# ==============================================================================
"""Utility functions for training."""
import six
import tensorflow as tf
slim = tf.contrib.slim
......@@ -44,7 +46,7 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
if labels is None:
raise ValueError('No label for softmax cross entropy loss.')
for scale, logits in scales_to_logits.iteritems():
for scale, logits in six.iteritems(scales_to_logits):
loss_scope = None
if scope:
loss_scope = '%s_%s' % (scope, scale)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment