"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "62dd95870c812d87418e53229eb3fdee95c8a067"
Commit 26c96542 authored by DefineFC's avatar DefineFC
Browse files

prepare data, train, and eval on py3

parent 65f4d60b
...@@ -154,10 +154,10 @@ def _convert_dataset(dataset_split): ...@@ -154,10 +154,10 @@ def _convert_dataset(dataset_split):
i + 1, num_images, shard_id)) i + 1, num_images, shard_id))
sys.stdout.flush() sys.stdout.flush()
# Read the image. # Read the image.
image_data = tf.gfile.FastGFile(image_files[i], 'r').read() image_data = tf.gfile.FastGFile(image_files[i], 'rb').read()
height, width = image_reader.read_image_dims(image_data) height, width = image_reader.read_image_dims(image_data)
# Read the semantic segmentation annotation. # Read the semantic segmentation annotation.
seg_data = tf.gfile.FastGFile(label_files[i], 'r').read() seg_data = tf.gfile.FastGFile(label_files[i], 'rb').read()
seg_height, seg_width = label_reader.read_image_dims(seg_data) seg_height, seg_width = label_reader.read_image_dims(seg_data)
if height != seg_height or width != seg_width: if height != seg_height or width != seg_width:
raise RuntimeError('Shape mismatched between image and label.') raise RuntimeError('Shape mismatched between image and label.')
......
...@@ -125,7 +125,10 @@ def _bytes_list_feature(values): ...@@ -125,7 +125,10 @@ def _bytes_list_feature(values):
Returns: Returns:
A TF-Feature. A TF-Feature.
""" """
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) def norm2bytes(value):
return value.encode() if isinstance(value, str) else value
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[norm2bytes(values)]))
def image_seg_to_tfexample(image_data, filename, height, width, seg_data): def image_seg_to_tfexample(image_data, filename, height, width, seg_data):
......
...@@ -114,13 +114,13 @@ def _convert_dataset(dataset_split): ...@@ -114,13 +114,13 @@ def _convert_dataset(dataset_split):
# Read the image. # Read the image.
image_filename = os.path.join( image_filename = os.path.join(
FLAGS.image_folder, filenames[i] + '.' + FLAGS.image_format) FLAGS.image_folder, filenames[i] + '.' + FLAGS.image_format)
image_data = tf.gfile.FastGFile(image_filename, 'r').read() image_data = tf.gfile.FastGFile(image_filename, 'rb').read()
height, width = image_reader.read_image_dims(image_data) height, width = image_reader.read_image_dims(image_data)
# Read the semantic segmentation annotation. # Read the semantic segmentation annotation.
seg_filename = os.path.join( seg_filename = os.path.join(
FLAGS.semantic_segmentation_folder, FLAGS.semantic_segmentation_folder,
filenames[i] + '.' + FLAGS.label_format) filenames[i] + '.' + FLAGS.label_format)
seg_data = tf.gfile.FastGFile(seg_filename, 'r').read() seg_data = tf.gfile.FastGFile(seg_filename, 'rb').read()
seg_height, seg_width = label_reader.read_image_dims(seg_data) seg_height, seg_width = label_reader.read_image_dims(seg_data)
if height != seg_height or width != seg_width: if height != seg_height or width != seg_width:
raise RuntimeError('Shape mismatched between image and label.') raise RuntimeError('Shape mismatched between image and label.')
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
See model.py for more details and usage. See model.py for more details and usage.
""" """
import six
import math import math
import tensorflow as tf import tensorflow as tf
from deeplab import common from deeplab import common
...@@ -144,7 +145,7 @@ def main(unused_argv): ...@@ -144,7 +145,7 @@ def main(unused_argv):
metrics_to_values, metrics_to_updates = ( metrics_to_values, metrics_to_updates = (
tf.contrib.metrics.aggregate_metric_map(metric_map)) tf.contrib.metrics.aggregate_metric_map(metric_map))
for metric_name, metric_value in metrics_to_values.iteritems(): for metric_name, metric_value in six.iteritems(metrics_to_values):
slim.summaries.add_scalar_summary( slim.summaries.add_scalar_summary(
metric_value, metric_name, print_summary=True) metric_value, metric_name, print_summary=True)
...@@ -163,7 +164,7 @@ def main(unused_argv): ...@@ -163,7 +164,7 @@ def main(unused_argv):
checkpoint_dir=FLAGS.checkpoint_dir, checkpoint_dir=FLAGS.checkpoint_dir,
logdir=FLAGS.eval_logdir, logdir=FLAGS.eval_logdir,
num_evals=num_batches, num_evals=num_batches,
eval_op=metrics_to_updates.values(), eval_op=list(metrics_to_updates.values()),
max_number_of_evaluations=num_eval_iters, max_number_of_evaluations=num_eval_iters,
eval_interval_secs=FLAGS.eval_interval_secs) eval_interval_secs=FLAGS.eval_interval_secs)
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
See model.py for more details and usage. See model.py for more details and usage.
""" """
import six
import tensorflow as tf import tensorflow as tf
from deeplab import common from deeplab import common
from deeplab import model from deeplab import model
...@@ -190,7 +191,7 @@ def _build_deeplab(inputs_queue, outputs_to_num_classes, ignore_label): ...@@ -190,7 +191,7 @@ def _build_deeplab(inputs_queue, outputs_to_num_classes, ignore_label):
is_training=True, is_training=True,
fine_tune_batch_norm=FLAGS.fine_tune_batch_norm) fine_tune_batch_norm=FLAGS.fine_tune_batch_norm)
for output, num_classes in outputs_to_num_classes.iteritems(): for output, num_classes in six.iteritems(outputs_to_num_classes):
train_utils.add_softmax_cross_entropy_loss_for_each_scale( train_utils.add_softmax_cross_entropy_loss_for_each_scale(
outputs_to_scales_to_logits[output], outputs_to_scales_to_logits[output],
samples[common.LABEL], samples[common.LABEL],
...@@ -217,7 +218,7 @@ def main(unused_argv): ...@@ -217,7 +218,7 @@ def main(unused_argv):
assert FLAGS.train_batch_size % config.num_clones == 0, ( assert FLAGS.train_batch_size % config.num_clones == 0, (
'Training batch size not divisble by number of clones (GPUs).') 'Training batch size not divisble by number of clones (GPUs).')
clone_batch_size = FLAGS.train_batch_size / config.num_clones clone_batch_size = int(FLAGS.train_batch_size / config.num_clones)
# Get dataset-dependent information. # Get dataset-dependent information.
dataset = segmentation_dataset.get_dataset( dataset = segmentation_dataset.get_dataset(
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
# ============================================================================== # ==============================================================================
"""Utility functions for training.""" """Utility functions for training."""
import six
import tensorflow as tf import tensorflow as tf
slim = tf.contrib.slim slim = tf.contrib.slim
...@@ -44,7 +46,7 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits, ...@@ -44,7 +46,7 @@ def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
if labels is None: if labels is None:
raise ValueError('No label for softmax cross entropy loss.') raise ValueError('No label for softmax cross entropy loss.')
for scale, logits in scales_to_logits.iteritems(): for scale, logits in six.iteritems(scales_to_logits):
loss_scope = None loss_scope = None
if scope: if scope:
loss_scope = '%s_%s' % (scope, scale) loss_scope = '%s_%s' % (scope, scale)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment