Commit 2bb1baad authored by xiangjinwu's avatar xiangjinwu Committed by Sergio Guadarrama
Browse files

slim Python 3 compatibility: cPickle and str/bytes (#1534)

parent 7e0016c5
...@@ -124,7 +124,7 @@ def read_label_file(dataset_dir, filename=LABELS_FILENAME): ...@@ -124,7 +124,7 @@ def read_label_file(dataset_dir, filename=LABELS_FILENAME):
A map from a label (integer) to class name. A map from a label (integer) to class name.
""" """
labels_filename = os.path.join(dataset_dir, filename) labels_filename = os.path.join(dataset_dir, filename)
with tf.gfile.Open(labels_filename, 'r') as f: with tf.gfile.Open(labels_filename, 'rb') as f:
lines = f.read().decode() lines = f.read().decode()
lines = lines.split('\n') lines = lines.split('\n')
lines = filter(None, lines) lines = filter(None, lines)
......
...@@ -26,7 +26,7 @@ from __future__ import absolute_import ...@@ -26,7 +26,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import cPickle from six.moves import cPickle
import os import os
import sys import sys
import tarfile import tarfile
...@@ -72,14 +72,17 @@ def _add_to_tfrecord(filename, tfrecord_writer, offset=0): ...@@ -72,14 +72,17 @@ def _add_to_tfrecord(filename, tfrecord_writer, offset=0):
Returns: Returns:
The new offset. The new offset.
""" """
with tf.gfile.Open(filename, 'r') as f: with tf.gfile.Open(filename, 'rb') as f:
if sys.version_info < (3,):
data = cPickle.load(f) data = cPickle.load(f)
else:
data = cPickle.load(f, encoding='bytes')
images = data['data'] images = data[b'data']
num_images = images.shape[0] num_images = images.shape[0]
images = images.reshape((num_images, 3, 32, 32)) images = images.reshape((num_images, 3, 32, 32))
labels = data['labels'] labels = data[b'labels']
with tf.Graph().as_default(): with tf.Graph().as_default():
image_placeholder = tf.placeholder(dtype=tf.uint8) image_placeholder = tf.placeholder(dtype=tf.uint8)
...@@ -99,7 +102,7 @@ def _add_to_tfrecord(filename, tfrecord_writer, offset=0): ...@@ -99,7 +102,7 @@ def _add_to_tfrecord(filename, tfrecord_writer, offset=0):
feed_dict={image_placeholder: image}) feed_dict={image_placeholder: image})
example = dataset_utils.image_to_tfexample( example = dataset_utils.image_to_tfexample(
png_string, 'png', _IMAGE_SIZE, _IMAGE_SIZE, label) png_string, b'png', _IMAGE_SIZE, _IMAGE_SIZE, label)
tfrecord_writer.write(example.SerializeToString()) tfrecord_writer.write(example.SerializeToString())
return offset + num_images return offset + num_images
......
...@@ -136,14 +136,14 @@ def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir): ...@@ -136,14 +136,14 @@ def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
sys.stdout.flush() sys.stdout.flush()
# Read the filename: # Read the filename:
image_data = tf.gfile.FastGFile(filenames[i], 'r').read() image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
height, width = image_reader.read_image_dims(sess, image_data) height, width = image_reader.read_image_dims(sess, image_data)
class_name = os.path.basename(os.path.dirname(filenames[i])) class_name = os.path.basename(os.path.dirname(filenames[i]))
class_id = class_names_to_ids[class_name] class_id = class_names_to_ids[class_name]
example = dataset_utils.image_to_tfexample( example = dataset_utils.image_to_tfexample(
image_data, 'jpg', height, width, class_id) image_data, b'jpg', height, width, class_id)
tfrecord_writer.write(example.SerializeToString()) tfrecord_writer.write(example.SerializeToString())
sys.stdout.write('\n') sys.stdout.write('\n')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment