Commit 30cf3752 authored by Joel Shor's avatar Joel Shor
Browse files

Made `generate_cifar10_tfrecords.py` python3 compatible.

parent f3a542b4
...@@ -25,6 +25,7 @@ from __future__ import print_function ...@@ -25,6 +25,7 @@ from __future__ import print_function
import argparse import argparse
import os import os
import sys
import tarfile import tarfile
from six.moves import cPickle as pickle from six.moves import cPickle as pickle
...@@ -63,7 +64,10 @@ def _get_file_names(): ...@@ -63,7 +64,10 @@ def _get_file_names():
def read_pickle_from_file(filename): def read_pickle_from_file(filename):
with tf.gfile.Open(filename, 'rb') as f: with tf.gfile.Open(filename, 'rb') as f:
data_dict = pickle.load(f) if sys.version_info >= (3, 0):
data_dict = pickle.load(f, encoding='bytes')
else:
data_dict = pickle.load(f)
return data_dict return data_dict
...@@ -73,8 +77,8 @@ def convert_to_tfrecord(input_files, output_file): ...@@ -73,8 +77,8 @@ def convert_to_tfrecord(input_files, output_file):
with tf.python_io.TFRecordWriter(output_file) as record_writer: with tf.python_io.TFRecordWriter(output_file) as record_writer:
for input_file in input_files: for input_file in input_files:
data_dict = read_pickle_from_file(input_file) data_dict = read_pickle_from_file(input_file)
data = data_dict['data'] data = data_dict[b'data']
labels = data_dict['labels'] labels = data_dict[b'labels']
num_entries_in_batch = len(labels) num_entries_in_batch = len(labels)
for i in range(num_entries_in_batch): for i in range(num_entries_in_batch):
example = tf.train.Example(features=tf.train.Features( example = tf.train.Example(features=tf.train.Features(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment