ModelZoo / ResNet50_tensorflow
Commit a7531875, authored Aug 20, 2017 by Toby Boyd
Updated readme and cleaned up formatting
parent 8008e72f
Showing 2 changed files with 30 additions and 50 deletions
tutorials/image/cifar10_estimator/README.md (+10 -32)
tutorials/image/cifar10_estimator/generate_cifar10_tfrecords.py (+20 -18)
tutorials/image/cifar10_estimator/README.md
@@ -14,40 +14,18 @@ Before trying to run the model we highly encourage you to read all the README.
 1. Install TensorFlow version 1.2.1 or later with GPU support.
 You can see how to do it [here](https://www.tensorflow.org/install/).
-2. Download the CIFAR-10 dataset.
-```shell
-curl -o cifar-10-python.tar.gz https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
-tar xzf cifar-10-python.tar.gz
-```
-After running the commands above, you should see the following files in the folder where the data was downloaded.
-```shell
-ls -R cifar-10-batches-py
-```
-The output should be:
-```
-batches.meta data_batch_1 data_batch_2 data_batch_3
-data_batch_4 data_batch_5 readme.html test_batch
-```
-3. Generate TFRecord files.
+2. Generate TFRecord files.
 This will generate a tf record for the training and test data available at the input_dir.
 You can see more details in `generate_cifar10_tf_records.py`
 ```shell
-python generate_cifar10_tfrecords.py --input-dir=${PWD}/cifar-10-batches-py \
---output-dir=${PWD}/cifar-10-batches-py
+python generate_cifar10_tfrecords.py --data-dir=${PWD}/cifar-10-data
 ```
 After running the command above, you should see the following new files in the output_dir.
 ```shell
-ls -R cifar-10-batches-py
+ls -R cifar-10-data
 ```
 ```
@@ -59,7 +37,7 @@ train.tfrecords validation.tfrecords eval.tfrecords
 Run the model on CPU only. After training, it runs the evaluation.
 ```
-python cifar10_main.py --data-dir=${PWD}/cifar-10-batches-py \
+python cifar10_main.py --data-dir=${PWD}/cifar-10-data \
 --job-dir=/tmp/cifar10 \
 --num-gpus=0 \
 --train-steps=1000
@@ -67,7 +45,7 @@ python cifar10_main.py --data-dir=${PWD}/cifar-10-batches-py \
 Run the model on 2 GPUs using CPU as parameter server. After training, it runs the evaluation.
 ```
-python cifar10_main.py --data-dir=${PWD}/cifar-10-batches-py \
+python cifar10_main.py --data-dir=${PWD}/cifar-10-data \
 --job-dir=/tmp/cifar10 \
 --num-gpus=2 \
 --train-steps=1000
@@ -78,7 +56,7 @@ It will run an experiment, which for local setting basically means it will run s
 a couple of times to perform evaluation.
 ```
-python cifar10_main.py --data-dir=${PWD}/cifar-10-batches-bin \
+python cifar10_main.py --data-dir=${PWD}/cifar-10-data \
 --job-dir=/tmp/cifar10 \
 --variable-strategy GPU \
 --num-gpus=2 \
@@ -98,7 +76,7 @@ You'll also need a Google Cloud Storage bucket for the data. If you followed the
 ```
 MY_BUCKET=gs://<my-bucket-name>
-gsutil cp -r ${PWD}/cifar-10-batches-py $MY_BUCKET/
+gsutil cp -r ${PWD}/cifar-10-data $MY_BUCKET/
 ```
 Then run the following command from the `tutorials/image` directory of this repository (the parent directory of this README):
@@ -111,7 +89,7 @@ gcloud ml-engine jobs submit training cifarmultigpu \
 --package-path cifar10_estimator/ \
 --module-name cifar10_estimator.cifar10_main \
 -- \
---data-dir=$MY_BUCKET/cifar-10-batches-py \
+--data-dir=$MY_BUCKET/cifar-10-data \
 --num-gpus=4 \
 --train-steps=1000
 ```
@@ -191,7 +169,7 @@ The num_workers arugument is used only to update the learning rate correctly.
 Make sure the model_dir is the same as defined on the TF_CONFIG.
 ```shell
-python cifar10_main.py --data-dir=gs://path/cifar-10-batches-py \
+python cifar10_main.py --data-dir=gs://path/cifar-10-data \
 --job-dir=gs://path/model_dir/ \
 --num-gpus=4 \
 --train-steps=40000 \
@@ -332,7 +310,7 @@ It will run evaluation a couple of times during training.
 Make sure the model_dir is the same as defined on the TF_CONFIG.
 ```shell
-python cifar10_main.py --data-dir=gs://path/cifar-10-batches-py \
+python cifar10_main.py --data-dir=gs://path/cifar-10-data \
 --job-dir=gs://path/model_dir/ \
 --num-gpus=4 \
 --train-steps=40000 \
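Since the README change renames the data directory to `cifar-10-data`, a quick check that the generation step left the expected files there may be useful before launching training. The following is a minimal sketch, not part of the commit: the three file names are taken from the hunk context line in the README diff (`train.tfrecords validation.tfrecords eval.tfrecords`), while the helper name `missing_tfrecords` and the constant `EXPECTED_TFRECORDS` are hypothetical.

```python
import os

# File names as listed in the README hunk context; the constant and the
# helper below are illustrative and do not exist in the repository.
EXPECTED_TFRECORDS = ('train.tfrecords', 'validation.tfrecords',
                      'eval.tfrecords')


def missing_tfrecords(data_dir, expected=EXPECTED_TFRECORDS):
    """Return the expected TFRecord file names that are absent from data_dir."""
    return [name for name in expected
            if not os.path.isfile(os.path.join(data_dir, name))]
```

Calling `missing_tfrecords('cifar-10-data')` after running `generate_cifar10_tfrecords.py` should return an empty list if all three files were written; any names it returns point to a split that was not generated.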
tutorials/image/cifar10_estimator/generate_cifar10_tfrecords.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Read CIFAR-10 data from pickled numpy arrays and write TFExamples.
+"""Read CIFAR-10 data from pickled numpy arrays and writes TFRecords.
 
-Generates TFRecord files from the python version of the CIFAR-10 dataset
-downloaded from https://www.cs.toronto.edu/~kriz/cifar.html.
+Generates tf.train.Example protos and writes them to TFRecord files from the
+python version of the CIFAR-10 dataset downloaded from
+https://www.cs.toronto.edu/~kriz/cifar.html.
 """
 
 from __future__ import absolute_import
@@ -30,14 +31,18 @@ import tarfile
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
 
-DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
-CIFAR_10_FILE_NAME = 'cifar-10-python.tar.gz'
+CIFAR_FILENAME = 'cifar-10-python.tar.gz'
+CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/' + CIFAR_FILENAME
 CIFAR_LOCAL_FOLDER = 'cifar-10-batches-py'
 
 def download_and_extract(data_dir):
   # download CIFAR-10 if not already downloaded.
-  tf.contrib.learn.datasets.base.maybe_download(CIFAR_10_FILE_NAME, data_dir, DATA_URL)
-  tarfile.open(os.path.join(data_dir, CIFAR_10_FILE_NAME), 'r:gz').extractall(data_dir)
+  tf.contrib.learn.datasets.base.maybe_download(CIFAR_FILENAME, data_dir, CIFAR_DOWNLOAD_URL)
+  tarfile.open(os.path.join(data_dir, CIFAR_FILENAME), 'r:gz').extractall(data_dir)
 
 def _int64_feature(value):
   return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
@@ -63,19 +68,17 @@ def read_pickle_from_file(filename):
 def convert_to_tfrecord(input_files, output_file):
-  """Converts a file to tfrecords."""
+  """Converts a file to TFRecords."""
   print('Generating %s' % output_file)
   with tf.python_io.TFRecordWriter(output_file) as record_writer:
     for input_file in input_files:
-      print(input_file)
       data_dict = read_pickle_from_file(input_file)
       data = data_dict['data']
       labels = data_dict['labels']
       num_entries_in_batch = len(labels)
       for i in range(num_entries_in_batch):
-        example = tf.train.Example(features=tf.train.Features(feature={
+        example = tf.train.Example(features=tf.train.Features(feature={
             'image': _bytes_feature(data[i].tobytes()),
             'label': _int64_feature(labels[i])
         }))
@@ -83,18 +86,18 @@ def convert_to_tfrecord(input_files, output_file):
 def main(data_dir):
+  print('Download from {} and extract.'.format(CIFAR_DOWNLOAD_URL))
   download_and_extract(data_dir)
   file_names = _get_file_names()
   input_dir = os.path.join(data_dir, CIFAR_LOCAL_FOLDER)
   for mode, files in file_names.items():
-    input_files = [os.path.join(input_dir, f) for f in files]
+    input_files = [os.path.join(input_dir, f) for f in files]
     output_file = os.path.join(data_dir, mode + '.tfrecords')
     try:
       os.remove(output_file)
     except OSError:
       pass
-    # Convert to Examples and write the result to TFRecords.
+    # Convert to tf.train.Example and write the to TFRecords.
     convert_to_tfrecord(input_files, output_file)
   print('Done!')
@@ -105,8 +108,7 @@ if __name__ == '__main__':
       '--data-dir',
       type=str,
       default='',
-      help='Directory to download and extract CIFAR-10 to.')
+      help='Directory to download and extract CIFAR-10 to.')
   args = parser.parse_args()
   main(args.data_dir)
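The `download_and_extract` helper in this change relies on `tf.contrib.learn.datasets.base.maybe_download`, and `tf.contrib` was removed in TensorFlow 2.x. As a rough illustration only, the same download-if-missing-then-extract flow could be sketched with the standard library alone. The constants mirror those introduced in the diff; the `maybe_download` function here is a hypothetical stand-in written for this sketch, not the TensorFlow function, and the extra `filename` and `url` parameters are assumptions added so the helper can be exercised without network access.

```python
import os
import tarfile
import urllib.request

# Constants mirror the ones introduced in the diff above.
CIFAR_FILENAME = 'cifar-10-python.tar.gz'
CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/' + CIFAR_FILENAME


def maybe_download(filename, data_dir, url):
    """Hypothetical stand-in: fetch url into data_dir unless already present."""
    os.makedirs(data_dir, exist_ok=True)
    path = os.path.join(data_dir, filename)
    if not os.path.exists(path):
        urllib.request.urlretrieve(url, path)
    return path


def download_and_extract(data_dir, filename=CIFAR_FILENAME,
                         url=CIFAR_DOWNLOAD_URL):
    """Download the archive if needed, then unpack it into data_dir."""
    archive = maybe_download(filename, data_dir, url)
    # The context manager closes the archive even if extraction fails.
    with tarfile.open(archive, 'r:gz') as tar:
        tar.extractall(data_dir)
```

With the defaults, `download_and_extract('cifar-10-data')` would be expected to leave the `cifar-10-batches-py` folder inside `cifar-10-data`, which is where `main` in the diff looks for its input files. Unlike the one-line `tarfile.open(...).extractall(...)` in the commit, this sketch closes the archive explicitly.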