Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
8008e72f
"...text-generation-inference.git" did not exist on "06edde94910594eef86988934cbbc43d775eb965"
Commit
8008e72f
authored
Aug 19, 2017
by
Toby Boyd
Browse files
Generate data downloads and creates files
parent
d8588a7e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
15 deletions
+25
-15
tutorials/image/cifar10_estimator/cifar10_utils.py
tutorials/image/cifar10_estimator/cifar10_utils.py
+2
-1
tutorials/image/cifar10_estimator/generate_cifar10_tfrecords.py
...als/image/cifar10_estimator/generate_cifar10_tfrecords.py
+23
-14
No files found.
tutorials/image/cifar10_estimator/cifar10_utils.py
View file @
8008e72f
import
collections
import
six
import
six
import
tensorflow
as
tf
from
tensorflow.python.platform
import
tf_logging
as
logging
from
tensorflow.python.platform
import
tf_logging
as
logging
from
tensorflow.core.framework
import
node_def_pb2
from
tensorflow.core.framework
import
node_def_pb2
from
tensorflow.python.framework
import
device
as
pydev
from
tensorflow.python.framework
import
device
as
pydev
from
tensorflow.python.training
import
basic_session_run_hooks
from
tensorflow.python.training
import
basic_session_run_hooks
...
...
tutorials/image/cifar10_estimator/generate_cifar10_tfrecords.py
View file @
8008e72f
...
@@ -26,8 +26,18 @@ import argparse
...
@@ -26,8 +26,18 @@ import argparse
import
cPickle
import
cPickle
import
os
import
os
import
tarfile
from
six.moves
import
xrange
# pylint: disable=redefined-builtin
import
tensorflow
as
tf
import
tensorflow
as
tf
DATA_URL
=
'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
CIFAR_10_FILE_NAME
=
'cifar-10-python.tar.gz'
CIFAR_LOCAL_FOLDER
=
'cifar-10-batches-py'
def
download_and_extract
(
data_dir
):
# download CIFAR-10 if not already downloaded.
tf
.
contrib
.
learn
.
datasets
.
base
.
maybe_download
(
CIFAR_10_FILE_NAME
,
data_dir
,
DATA_URL
)
tarfile
.
open
(
os
.
path
.
join
(
data_dir
,
CIFAR_10_FILE_NAME
),
'r:gz'
).
extractall
(
data_dir
)
def
_int64_feature
(
value
):
def
_int64_feature
(
value
):
return
tf
.
train
.
Feature
(
int64_list
=
tf
.
train
.
Int64List
(
value
=
[
value
]))
return
tf
.
train
.
Feature
(
int64_list
=
tf
.
train
.
Int64List
(
value
=
[
value
]))
...
@@ -57,6 +67,7 @@ def convert_to_tfrecord(input_files, output_file):
...
@@ -57,6 +67,7 @@ def convert_to_tfrecord(input_files, output_file):
print
(
'Generating %s'
%
output_file
)
print
(
'Generating %s'
%
output_file
)
with
tf
.
python_io
.
TFRecordWriter
(
output_file
)
as
record_writer
:
with
tf
.
python_io
.
TFRecordWriter
(
output_file
)
as
record_writer
:
for
input_file
in
input_files
:
for
input_file
in
input_files
:
print
(
input_file
)
data_dict
=
read_pickle_from_file
(
input_file
)
data_dict
=
read_pickle_from_file
(
input_file
)
data
=
data_dict
[
'data'
]
data
=
data_dict
[
'data'
]
labels
=
data_dict
[
'labels'
]
labels
=
data_dict
[
'labels'
]
...
@@ -71,12 +82,18 @@ def convert_to_tfrecord(input_files, output_file):
...
@@ -71,12 +82,18 @@ def convert_to_tfrecord(input_files, output_file):
record_writer
.
write
(
example
.
SerializeToString
())
record_writer
.
write
(
example
.
SerializeToString
())
def
main
(
input_dir
,
output_dir
):
def
main
(
data_dir
):
download_and_extract
(
data_dir
)
file_names
=
_get_file_names
()
file_names
=
_get_file_names
()
input_dir
=
os
.
path
.
join
(
data_dir
,
CIFAR_LOCAL_FOLDER
)
for
mode
,
files
in
file_names
.
items
():
for
mode
,
files
in
file_names
.
items
():
input_files
=
[
input_files
=
[
os
.
path
.
join
(
input_dir
,
f
)
for
f
in
files
]
os
.
path
.
join
(
input_dir
,
f
)
for
f
in
files
]
output_file
=
os
.
path
.
join
(
output_dir
,
mode
+
'.tfrecords'
)
output_file
=
os
.
path
.
join
(
data_dir
,
mode
+
'.tfrecords'
)
try
:
os
.
remove
(
output_file
)
except
OSError
:
pass
# Convert to Examples and write the result to TFRecords.
# Convert to Examples and write the result to TFRecords.
convert_to_tfrecord
(
input_files
,
output_file
)
convert_to_tfrecord
(
input_files
,
output_file
)
print
(
'Done!'
)
print
(
'Done!'
)
...
@@ -85,19 +102,11 @@ def main(input_dir, output_dir):
...
@@ -85,19 +102,11 @@ def main(input_dir, output_dir):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
parser
.
add_argument
(
'--input-dir'
,
'--data-dir'
,
type
=
str
,
default
=
''
,
help
=
'Directory where CIFAR10 data is located.'
)
parser
.
add_argument
(
'--output-dir'
,
type
=
str
,
type
=
str
,
default
=
''
,
default
=
''
,
help
=
"""
\
help
=
'Directory to download and extract CIFAR-10 to.'
Directory where TFRecords will be saved.The TFRecords will have the same
name as the CIFAR10 inputs + .tfrecords.
\
"""
)
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
.
input_dir
,
args
.
output
_dir
)
main
(
args
.
data
_dir
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment