Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
6f6bc501
Commit
6f6bc501
authored
Jul 28, 2017
by
Marianne Linhares Monteiro
Committed by
GitHub
Jul 28, 2017
Browse files
Create generate_cifar10_tfrecords.py
parent
7bea23de
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
90 additions
and
0 deletions
+90
-0
tutorials/image/cifar10_estimator/generate_cifar10_tfrecords.py
...als/image/cifar10_estimator/generate_cifar10_tfrecords.py
+90
-0
No files found.
tutorials/image/cifar10_estimator/generate_cifar10_tfrecords.py
0 → 100644
View file @
6f6bc501
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Read CIFAR-10 data from pickled numpy arrays and write TFExamples.
Generates TFRecord files from the python version of the CIFAR-10 dataset
downloaded from https://www.cs.toronto.edu/~kriz/cifar.html.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
cPickle
import
os
import
tensorflow
as
tf
FLAGS
=
tf
.
flags
.
FLAGS
tf
.
flags
.
DEFINE_string
(
'input_dir'
,
''
,
'Directory where CIFAR10 data is located.'
)
tf
.
flags
.
DEFINE_string
(
'output_dir'
,
''
,
'Directory where TFRecords will be saved.'
'The TFRecords will have the same name as'
' the CIFAR10 inputs + .tfrecords.'
)
def
_int64_feature
(
value
):
return
tf
.
train
.
Feature
(
int64_list
=
tf
.
train
.
Int64List
(
value
=
[
value
]))
def
_bytes_feature
(
value
):
return
tf
.
train
.
Feature
(
bytes_list
=
tf
.
train
.
BytesList
(
value
=
[
str
(
value
)]))
def
_get_file_names
():
"""Returns the file names expected to exist in the input_dir."""
file_names
=
[
'data_batch_%d'
%
i
for
i
in
xrange
(
1
,
6
)]
file_names
.
append
(
'test_batch'
)
return
file_names
def
read_pickle_from_file
(
filename
):
with
open
(
filename
,
'r'
)
as
f
:
data_dict
=
cPickle
.
load
(
f
)
return
data_dict
def
main
(
argv
):
del
argv
# Unused.
file_names
=
_get_file_names
()
for
file_name
in
file_names
:
input_file
=
os
.
path
.
join
(
FLAGS
.
input_dir
,
file_name
)
output_file
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
file_name
+
'.tfrecords'
)
print
(
'Generating %s'
%
output_file
)
record_writer
=
tf
.
python_io
.
TFRecordWriter
(
output_file
)
data_dict
=
read_pickle_from_file
(
input_file
)
data
=
data_dict
[
'data'
]
labels
=
data_dict
[
'labels'
]
num_entries_in_batch
=
len
(
labels
)
for
i
in
range
(
num_entries_in_batch
):
example
=
tf
.
train
.
Example
(
features
=
tf
.
train
.
Features
(
feature
=
{
'image'
:
_bytes_feature
(
data
[
i
].
tobytes
()),
'label'
:
_int64_feature
(
labels
[
i
])
}))
record_writer
.
write
(
example
.
SerializeToString
())
record_writer
.
close
()
print
(
'Done!'
)
if
__name__
==
'__main__'
:
tf
.
app
.
run
(
main
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment