ModelZoo / ResNet50_tensorflow · Commits

Commit 34e79348 (unverified)
Authored Mar 23, 2018 by Taylor Robie; committed by GitHub on Mar 23, 2018
Parent: 1d38a225

Arg parsing cleanup for MNIST and Wide-Deep (#3684)

* move wide_deep parser
* move mnist parsers
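All three scripts get the same treatment: flags are no longer parsed into a module-level FLAGS global and dispatched through tf.app.run; instead, each main takes argv and parses it locally. A minimal before/after sketch of the pattern (do_work is a placeholder, not a name from this commit):

import argparse
import sys

def do_work(batch_size):  # placeholder for the real training logic
  print('training with batch size', batch_size)

# Before: FLAGS was a module-level global, set in __main__ via
# parse_known_args(), and main(_) read it implicitly.
# After: main owns its argv and parses it, so no global state is needed.
def main(argv):
  parser = argparse.ArgumentParser()
  parser.add_argument('--batch_size', type=int, default=100)
  flags = parser.parse_args(args=argv[1:])  # argv[0] is the program name
  do_work(flags.batch_size)

if __name__ == '__main__':
  main(argv=sys.argv)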
Showing 3 changed files with 52 additions and 51 deletions (+52 −51):
official/mnist/mnist.py          +19 −18
official/mnist/mnist_eager.py    +19 −20
official/wide_deep/wide_deep.py  +14 −13
official/mnist/mnist.py

@@ -175,11 +175,14 @@ def validate_batch_size_for_multi_gpu(batch_size):
     raise ValueError(err)
 
 
-def main(_):
+def main(argv):
+  parser = MNISTArgParser()
+  flags = parser.parse_args(args=argv[1:])
+
   model_function = model_fn
 
-  if FLAGS.multi_gpu:
-    validate_batch_size_for_multi_gpu(FLAGS.batch_size)
+  if flags.multi_gpu:
+    validate_batch_size_for_multi_gpu(flags.batch_size)
 
     # There are two steps required if using multi-GPU: (1) wrap the model_fn,
     # and (2) wrap the optimizer. The first happens here, and (2) happens
@@ -187,16 +190,16 @@ def main(_):
     model_function = tf.contrib.estimator.replicate_model_fn(
         model_fn, loss_reduction=tf.losses.Reduction.MEAN)
 
-  data_format = FLAGS.data_format
+  data_format = flags.data_format
   if data_format is None:
     data_format = ('channels_first'
                    if tf.test.is_built_with_cuda() else 'channels_last')
 
   mnist_classifier = tf.estimator.Estimator(
       model_fn=model_function,
-      model_dir=FLAGS.model_dir,
+      model_dir=flags.model_dir,
       params={
           'data_format': data_format,
-          'multi_gpu': FLAGS.multi_gpu
+          'multi_gpu': flags.multi_gpu
       })
 
   # Set up training and evaluation input functions.
@@ -206,35 +209,35 @@ def main(_):
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes use less memory. MNIST is a small
    # enough dataset that we can easily shuffle the full epoch.
-    ds = dataset.train(FLAGS.data_dir)
-    ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size)
+    ds = dataset.train(flags.data_dir)
+    ds = ds.cache().shuffle(buffer_size=50000).batch(flags.batch_size)
 
     # Iterate through the dataset a set number (`epochs_between_evals`) of times
     # during each training session.
-    ds = ds.repeat(FLAGS.epochs_between_evals)
+    ds = ds.repeat(flags.epochs_between_evals)
     return ds
 
   def eval_input_fn():
-    return dataset.test(FLAGS.data_dir).batch(
-        FLAGS.batch_size).make_one_shot_iterator().get_next()
+    return dataset.test(flags.data_dir).batch(
+        flags.batch_size).make_one_shot_iterator().get_next()
 
   # Set up hook that outputs training logs every 100 steps.
   train_hooks = hooks_helper.get_train_hooks(
-      FLAGS.hooks, batch_size=FLAGS.batch_size)
+      flags.hooks, batch_size=flags.batch_size)
 
   # Train and evaluate model.
-  for _ in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
+  for _ in range(flags.train_epochs // flags.epochs_between_evals):
     mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
     eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
     print('\nEvaluation results:\n\t%s\n' % eval_results)
 
   # Export the model
-  if FLAGS.export_dir is not None:
+  if flags.export_dir is not None:
     image = tf.placeholder(tf.float32, [None, 28, 28])
     input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
         'image': image,
     })
-    mnist_classifier.export_savedmodel(FLAGS.export_dir, input_fn)
 
+    mnist_classifier.export_savedmodel(flags.export_dir, input_fn)
 
 class MNISTArgParser(argparse.ArgumentParser):
@@ -261,6 +264,4 @@ class MNISTArgParser(argparse.ArgumentParser):
 
 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.INFO)
-  parser = MNISTArgParser()
-  FLAGS, unparsed = parser.parse_known_args()
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+  main(argv=sys.argv)
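A side benefit of the new signature, shown here as a hedged sketch rather than anything in the commit: main can now be driven with a synthetic argv (for example from a smoke test) without mutating globals. The flag names below are the ones read via flags.* in the diff above; the values are arbitrary.

from official.mnist import mnist

# Hypothetical driver: argv[0] stands in for the program name.
mnist.main(argv=['mnist.py',
                 '--batch_size', '100',
                 '--train_epochs', '1',
                 '--model_dir', '/tmp/mnist_model'])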
official/mnist/mnist_eager.py

@@ -38,8 +38,6 @@ from official.mnist import dataset as mnist_dataset
 from official.mnist import mnist
 from official.utils.arg_parsers import parsers
 
-FLAGS = None
-
 
 def loss(logits, labels):
   return tf.reduce_mean(
@@ -97,35 +95,38 @@ def test(model, dataset):
     tf.contrib.summary.scalar('accuracy', accuracy.result())
 
 
-def main(_):
+def main(argv):
+  parser = MNISTEagerArgParser()
+  flags = parser.parse_args(args=argv[1:])
+
   tfe.enable_eager_execution()
 
   # Automatically determine device and data_format
   (device, data_format) = ('/gpu:0', 'channels_first')
-  if FLAGS.no_gpu or tfe.num_gpus() <= 0:
+  if flags.no_gpu or tfe.num_gpus() <= 0:
     (device, data_format) = ('/cpu:0', 'channels_last')
   # If data_format is defined in FLAGS, overwrite automatically set value.
-  if FLAGS.data_format is not None:
+  if flags.data_format is not None:
     data_format = data_format
   print('Using device %s, and data format %s.' % (device, data_format))
 
   # Load the datasets
-  train_ds = mnist_dataset.train(FLAGS.data_dir).shuffle(60000).batch(
-      FLAGS.batch_size)
-  test_ds = mnist_dataset.test(FLAGS.data_dir).batch(FLAGS.batch_size)
+  train_ds = mnist_dataset.train(flags.data_dir).shuffle(60000).batch(
+      flags.batch_size)
+  test_ds = mnist_dataset.test(flags.data_dir).batch(flags.batch_size)
 
   # Create the model and optimizer
   model = mnist.Model(data_format)
-  optimizer = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.momentum)
+  optimizer = tf.train.MomentumOptimizer(flags.lr, flags.momentum)
 
   # Create file writers for writing TensorBoard summaries.
-  if FLAGS.output_dir:
+  if flags.output_dir:
     # Create directories to which summaries will be written
     # tensorboard --logdir=<output_dir>
     # can then be used to see the recorded summaries.
-    train_dir = os.path.join(FLAGS.output_dir, 'train')
-    test_dir = os.path.join(FLAGS.output_dir, 'eval')
-    tf.gfile.MakeDirs(FLAGS.output_dir)
+    train_dir = os.path.join(flags.output_dir, 'train')
+    test_dir = os.path.join(flags.output_dir, 'eval')
+    tf.gfile.MakeDirs(flags.output_dir)
   else:
     train_dir = None
     test_dir = None
@@ -135,19 +136,19 @@ def main(_):
       test_dir, flush_millis=10000, name='test')
 
   # Create and restore checkpoint (if one exists on the path)
-  checkpoint_prefix = os.path.join(FLAGS.model_dir, 'ckpt')
+  checkpoint_prefix = os.path.join(flags.model_dir, 'ckpt')
   step_counter = tf.train.get_or_create_global_step()
   checkpoint = tfe.Checkpoint(
       model=model, optimizer=optimizer, step_counter=step_counter)
   # Restore variables on creation if a checkpoint exists.
-  checkpoint.restore(tf.train.latest_checkpoint(FLAGS.model_dir))
+  checkpoint.restore(tf.train.latest_checkpoint(flags.model_dir))
 
   # Train and evaluate for a set number of epochs.
   with tf.device(device):
-    for _ in range(FLAGS.train_epochs):
+    for _ in range(flags.train_epochs):
       start = time.time()
       with summary_writer.as_default():
-        train(model, optimizer, train_ds, step_counter, FLAGS.log_interval)
+        train(model, optimizer, train_ds, step_counter, flags.log_interval)
       end = time.time()
       print('\nTrain time for epoch #%d (%d total steps): %f' %
             (checkpoint.save_counter.numpy() + 1,
@@ -205,6 +206,4 @@ class MNISTEagerArgParser(argparse.ArgumentParser):
   )
 
 if __name__ == '__main__':
-  parser = MNISTEagerArgParser()
-  FLAGS, unparsed = parser.parse_known_args()
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+  main(argv=sys.argv)
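The removed FLAGS = None, together with the existing import of official.utils.arg_parsers.parsers, points to the parser-subclass pattern these scripts use: the parser defines its flags in __init__, so constructing it is all main needs to do. An illustrative sketch with made-up defaults; the real flag definitions are not part of this diff.

import argparse

class MNISTEagerArgParser(argparse.ArgumentParser):
  # Illustrative only: flag names mirror the attributes main reads
  # (flags.data_dir, flags.batch_size, flags.lr, flags.momentum, ...).
  def __init__(self):
    super(MNISTEagerArgParser, self).__init__()
    self.add_argument('--data_dir', type=str, default='/tmp/mnist_data')
    self.add_argument('--batch_size', type=int, default=100)
    self.add_argument('--lr', type=float, default=0.01)
    self.add_argument('--momentum', type=float, default=0.5)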
official/wide_deep/wide_deep.py

@@ -171,33 +171,36 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
   return dataset
 
 
-def main(_):
+def main(argv):
+  parser = WideDeepArgParser()
+  flags = parser.parse_args(args=argv[1:])
+
   # Clean up the model directory if present
-  shutil.rmtree(FLAGS.model_dir, ignore_errors=True)
-  model = build_estimator(FLAGS.model_dir, FLAGS.model_type)
+  shutil.rmtree(flags.model_dir, ignore_errors=True)
+  model = build_estimator(flags.model_dir, flags.model_type)
 
-  train_file = os.path.join(FLAGS.data_dir, 'adult.data')
-  test_file = os.path.join(FLAGS.data_dir, 'adult.test')
+  train_file = os.path.join(flags.data_dir, 'adult.data')
+  test_file = os.path.join(flags.data_dir, 'adult.test')
 
   # Train and evaluate the model every `FLAGS.epochs_per_eval` epochs.
   def train_input_fn():
-    return input_fn(train_file, FLAGS.epochs_per_eval, True, FLAGS.batch_size)
+    return input_fn(train_file, flags.epochs_per_eval, True, flags.batch_size)
 
   def eval_input_fn():
-    return input_fn(test_file, 1, False, FLAGS.batch_size)
+    return input_fn(test_file, 1, False, flags.batch_size)
 
   train_hooks = hooks_helper.get_train_hooks(
-      FLAGS.hooks, batch_size=FLAGS.batch_size,
+      flags.hooks, batch_size=flags.batch_size,
       tensors_to_log={'average_loss': 'head/truediv',
                       'loss': 'head/weighted_loss/Sum'})
 
   # Train and evaluate the model every `FLAGS.epochs_between_evals` epochs.
-  for n in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
+  for n in range(flags.train_epochs // flags.epochs_between_evals):
     model.train(input_fn=train_input_fn, hooks=train_hooks)
     results = model.evaluate(input_fn=eval_input_fn)
 
     # Display evaluation metrics
-    print('Results at epoch', (n + 1) * FLAGS.epochs_between_evals)
+    print('Results at epoch', (n + 1) * flags.epochs_between_evals)
     print('-' * 60)
 
     for key in sorted(results):
@@ -224,6 +227,4 @@ class WideDeepArgParser(argparse.ArgumentParser):
 
 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.INFO)
-  parser = WideDeepArgParser()
-  FLAGS, unparsed = parser.parse_known_args()
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+  main(argv=sys.argv)
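One behavioral note on the switch from parse_known_args() to parse_args() across all three scripts (standard argparse behavior, not something the commit message calls out): unknown flags now fail fast with a usage error instead of being silently forwarded to tf.app.run. A self-contained demonstration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model_type', default='wide_deep')

# parse_known_args(): unrecognized flags are handed back to the caller.
flags, unparsed = parser.parse_known_args(['--model_type', 'wide', '--typo', '1'])
print(flags.model_type)  # wide
print(unparsed)          # ['--typo', '1']

# parse_args(): the same input prints a usage message and exits.
# parser.parse_args(['--model_type', 'wide', '--typo', '1'])  # raises SystemExit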