Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
6f881f77
"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "1c5dca7f47d5bea8dcdaf989ffb3a5984de4d27a"
Commit
6f881f77
authored
Dec 20, 2018
by
Shining Sun
Browse files
bug fixes and clean ups
parent
b1b4c805
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
38 additions
and
19 deletions
+38
-19
official/resnet/cifar10_main.py
official/resnet/cifar10_main.py
+3
-3
official/resnet/keras/keras_cifar_main.py
official/resnet/keras/keras_cifar_main.py
+11
-4
official/resnet/keras/keras_common.py
official/resnet/keras/keras_common.py
+8
-3
official/resnet/keras/keras_imagenet_main.py
official/resnet/keras/keras_imagenet_main.py
+12
-4
official/resnet/resnet_run_loop.py
official/resnet/resnet_run_loop.py
+0
-1
official/utils/misc/distribution_utils.py
official/utils/misc/distribution_utils.py
+4
-4
No files found.
official/resnet/cifar10_main.py
View file @
6f881f77
...
@@ -39,7 +39,7 @@ NUM_CLASSES = 10
...
@@ -39,7 +39,7 @@ NUM_CLASSES = 10
_NUM_DATA_FILES
=
5
_NUM_DATA_FILES
=
5
# TODO(tobyboyd): Change to best practice 45K(train)/5K(val)/10K(test) splits.
# TODO(tobyboyd): Change to best practice 45K(train)/5K(val)/10K(test) splits.
_
NUM_IMAGES
=
{
NUM_IMAGES
=
{
'train'
:
50000
,
'train'
:
50000
,
'validation'
:
10000
,
'validation'
:
10000
,
}
}
...
@@ -134,7 +134,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
...
@@ -134,7 +134,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
dataset
=
dataset
,
dataset
=
dataset
,
is_training
=
is_training
,
is_training
=
is_training
,
batch_size
=
batch_size
,
batch_size
=
batch_size
,
shuffle_buffer
=
_
NUM_IMAGES
[
'train'
],
shuffle_buffer
=
NUM_IMAGES
[
'train'
],
parse_record_fn
=
parse_record_fn
,
parse_record_fn
=
parse_record_fn
,
num_epochs
=
num_epochs
,
num_epochs
=
num_epochs
,
dtype
=
dtype
,
dtype
=
dtype
,
...
@@ -200,7 +200,7 @@ def cifar10_model_fn(features, labels, mode, params):
...
@@ -200,7 +200,7 @@ def cifar10_model_fn(features, labels, mode, params):
# Learning rate schedule follows arXiv:1512.03385 for ResNet-56 and under.
# Learning rate schedule follows arXiv:1512.03385 for ResNet-56 and under.
learning_rate_fn
=
resnet_run_loop
.
learning_rate_with_decay
(
learning_rate_fn
=
resnet_run_loop
.
learning_rate_with_decay
(
batch_size
=
params
[
'batch_size'
],
batch_denom
=
128
,
batch_size
=
params
[
'batch_size'
],
batch_denom
=
128
,
num_images
=
_
NUM_IMAGES
[
'train'
],
boundary_epochs
=
[
91
,
136
,
182
],
num_images
=
NUM_IMAGES
[
'train'
],
boundary_epochs
=
[
91
,
136
,
182
],
decay_rates
=
[
1
,
0.1
,
0.01
,
0.001
])
decay_rates
=
[
1
,
0.1
,
0.01
,
0.001
])
# Weight decay of 2e-4 diverges from 1e-4 decay used in the ResNet paper
# Weight decay of 2e-4 diverges from 1e-4 decay used in the ResNet paper
...
...
official/resnet/keras/keras_cifar_main.py
View file @
6f881f77
...
@@ -152,18 +152,24 @@ def run(flags_obj):
...
@@ -152,18 +152,24 @@ def run(flags_obj):
model
.
compile
(
loss
=
'categorical_crossentropy'
,
model
.
compile
(
loss
=
'categorical_crossentropy'
,
optimizer
=
optimizer
,
optimizer
=
optimizer
,
metrics
=
[
'categorical_accuracy'
],
metrics
=
[
'categorical_accuracy'
],
str
ategy
=
strategy
)
di
str
ibute
=
strategy
)
time_callback
,
tensorboard_callback
,
lr_callback
=
keras_common
.
get_callbacks
(
time_callback
,
tensorboard_callback
,
lr_callback
=
keras_common
.
get_callbacks
(
learning_rate_schedule
,
cifar_main
.
NUM_IMAGES
[
'train'
])
learning_rate_schedule
,
cifar_main
.
NUM_IMAGES
[
'train'
])
steps_per_epoch
=
cifar_main
.
NUM_IMAGES
[
'train'
]
//
flags_obj
.
batch_size
train_steps
=
cifar_main
.
NUM_IMAGES
[
'train'
]
//
flags_obj
.
batch_size
train_epochs
=
flags_obj
.
train_epochs
if
flags_obj
.
train_steps
:
train_steps
=
min
(
flags_obj
.
train_steps
,
train_steps
)
train_epochs
=
1
num_eval_steps
=
(
cifar_main
.
NUM_IMAGES
[
'validation'
]
//
num_eval_steps
=
(
cifar_main
.
NUM_IMAGES
[
'validation'
]
//
flags_obj
.
batch_size
)
flags_obj
.
batch_size
)
history
=
model
.
fit
(
train_input_dataset
,
history
=
model
.
fit
(
train_input_dataset
,
epochs
=
flags_obj
.
train_epochs
,
epochs
=
train_epochs
,
steps_per_epoch
=
steps_per_epoch
,
steps_per_epoch
=
train_steps
,
callbacks
=
[
callbacks
=
[
time_callback
,
time_callback
,
lr_callback
,
lr_callback
,
...
@@ -190,4 +196,5 @@ def main(_):
...
@@ -190,4 +196,5 @@ def main(_):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
cifar_main
.
define_cifar_flags
()
cifar_main
.
define_cifar_flags
()
keras_common
.
define_keras_flags
()
absl_app
.
run
(
main
)
absl_app
.
run
(
main
)
official/resnet/keras/keras_common.py
View file @
6f881f77
...
@@ -56,8 +56,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
...
@@ -56,8 +56,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
def
on_batch_end
(
self
,
batch
,
logs
=
None
):
def
on_batch_end
(
self
,
batch
,
logs
=
None
):
if
batch
%
self
.
log_batch_size
==
0
:
if
batch
%
self
.
log_batch_size
==
0
:
last_n_batches
=
time
.
time
()
-
self
.
batch_time_start
last_n_batches
=
time
.
time
()
-
self
.
batch_time_start
examples_per_second
=
examples_per_second
=
(
self
.
_batch_size
*
self
.
log_batch_size
)
/
last_n_batches
(
self
.
_batch_size
*
self
.
log_batch_size
)
/
last_n_batches
self
.
batch_times_secs
.
append
(
last_n_batches
)
self
.
batch_times_secs
.
append
(
last_n_batches
)
self
.
record_batch
=
True
self
.
record_batch
=
True
# TODO(anjalisridhar): add timestamp as well.
# TODO(anjalisridhar): add timestamp as well.
...
@@ -131,8 +130,14 @@ def analyze_fit_and_eval_result(history, eval_output):
...
@@ -131,8 +130,14 @@ def analyze_fit_and_eval_result(history, eval_output):
stats
[
'training_loss'
]
=
history
.
history
[
'loss'
][
-
1
]
stats
[
'training_loss'
]
=
history
.
history
[
'loss'
][
-
1
]
stats
[
'training_accuracy_top_1'
]
=
history
.
history
[
'categorical_accuracy'
][
-
1
]
stats
[
'training_accuracy_top_1'
]
=
history
.
history
[
'categorical_accuracy'
][
-
1
]
print
(
'Test loss:{}'
.
format
(
stats
[
''
]))
print
(
'Test loss:{}'
.
format
(
stats
[
'
eval_loss
'
]))
print
(
'top_1 accuracy:{}'
.
format
(
stats
[
'accuracy_top_1'
]))
print
(
'top_1 accuracy:{}'
.
format
(
stats
[
'accuracy_top_1'
]))
print
(
'top_1_training_accuracy:{}'
.
format
(
stats
[
'training_accuracy_top_1'
]))
print
(
'top_1_training_accuracy:{}'
.
format
(
stats
[
'training_accuracy_top_1'
]))
return
stats
return
stats
def
define_keras_flags
():
flags
.
DEFINE_boolean
(
name
=
'enable_eager'
,
default
=
False
,
help
=
'Enable eager?'
)
flags
.
DEFINE_integer
(
name
=
"train_steps"
,
default
=
None
,
help
=
"The number of steps to run for training"
)
official/resnet/keras/keras_imagenet_main.py
View file @
6f881f77
...
@@ -68,12 +68,12 @@ def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batc
...
@@ -68,12 +68,12 @@ def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batc
def
parse_record_keras
(
raw_record
,
is_training
,
dtype
):
def
parse_record_keras
(
raw_record
,
is_training
,
dtype
):
"""Adjust the shape of label."""
"""Adjust the shape of label."""
image
,
label
=
imagenet_main
.
parse_record
(
raw_record
,
is_training
,
dtype
)
image
,
label
=
imagenet_main
.
parse_record
(
raw_record
,
is_training
,
dtype
)
# Subtract one so that labels are in [0, 1000), and cast to float32 for
# Subtract one so that labels are in [0, 1000), and cast to float32 for
# Keras model.
# Keras model.
label
=
tf
.
cast
(
tf
.
cast
(
tf
.
reshape
(
label
,
shape
=
[
1
]),
dtype
=
tf
.
int32
)
-
1
,
label
=
tf
.
cast
(
tf
.
cast
(
tf
.
reshape
(
label
,
shape
=
[
1
]),
dtype
=
tf
.
int32
)
-
1
,
dtype
=
tf
.
float32
)
dtype
=
tf
.
float32
)
return
image
,
label
return
image
,
label
...
@@ -153,9 +153,16 @@ def run_imagenet_with_keras(flags_obj):
...
@@ -153,9 +153,16 @@ def run_imagenet_with_keras(flags_obj):
num_eval_steps
=
(
imagenet_main
.
NUM_IMAGES
[
'validation'
]
//
num_eval_steps
=
(
imagenet_main
.
NUM_IMAGES
[
'validation'
]
//
flags_obj
.
batch_size
)
flags_obj
.
batch_size
)
train_steps
=
imagenet_main
.
NUM_IMAGES
[
'train'
]
//
flags_obj
.
batch_size
train_epochs
=
flags_obj
.
train_epochs
if
flags_obj
.
train_steps
:
train_steps
=
min
(
flags_obj
.
train_steps
,
train_steps
)
train_epochs
=
1
history
=
model
.
fit
(
train_input_dataset
,
history
=
model
.
fit
(
train_input_dataset
,
epochs
=
flags_obj
.
train_epochs
,
epochs
=
train_epochs
,
steps_per_epoch
=
steps_per_epoch
,
steps_per_epoch
=
train_steps
,
callbacks
=
[
callbacks
=
[
time_callback
,
time_callback
,
lr_callback
,
lr_callback
,
...
@@ -182,4 +189,5 @@ def main(_):
...
@@ -182,4 +189,5 @@ def main(_):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
imagenet_main
.
define_imagenet_flags
()
imagenet_main
.
define_imagenet_flags
()
keras_common
.
define_keras_flags
()
absl_app
.
run
(
main
)
absl_app
.
run
(
main
)
official/resnet/resnet_run_loop.py
View file @
6f881f77
...
@@ -632,7 +632,6 @@ def define_resnet_flags(resnet_size_choices=None):
...
@@ -632,7 +632,6 @@ def define_resnet_flags(resnet_size_choices=None):
name
=
'use_one_device_strategy'
,
default
=
True
,
name
=
'use_one_device_strategy'
,
default
=
True
,
help
=
flags_core
.
help_wrap
(
'Set to False to not use distribution '
help
=
flags_core
.
help_wrap
(
'Set to False to not use distribution '
'strategies.'
))
'strategies.'
))
flags
.
DEFINE_boolean
(
name
=
'enable_eager'
,
default
=
False
,
help
=
'Enable eager?'
)
flags
.
DEFINE_boolean
(
name
=
'use_tf_momentum_optimizer'
,
default
=
False
,
flags
.
DEFINE_boolean
(
name
=
'use_tf_momentum_optimizer'
,
default
=
False
,
help
=
'Use tf MomentumOptimizer.'
)
help
=
'Use tf MomentumOptimizer.'
)
...
...
official/utils/misc/distribution_utils.py
View file @
6f881f77
...
@@ -22,7 +22,7 @@ import tensorflow as tf
...
@@ -22,7 +22,7 @@ import tensorflow as tf
def
get_distribution_strategy
(
def
get_distribution_strategy
(
num_gpus
,
all_reduce_alg
=
None
,
use_one_device_strategy
):
num_gpus
,
all_reduce_alg
=
None
,
use_one_device_strategy
=
True
):
"""Return a DistributionStrategy for running the model.
"""Return a DistributionStrategy for running the model.
Args:
Args:
...
@@ -31,8 +31,8 @@ def get_distribution_strategy(
...
@@ -31,8 +31,8 @@ def get_distribution_strategy(
See tf.contrib.distribute.AllReduceCrossDeviceOps for available
See tf.contrib.distribute.AllReduceCrossDeviceOps for available
algorithms. If None, DistributionStrategy will choose based on device
algorithms. If None, DistributionStrategy will choose based on device
topology.
topology.
use_one_device_strategy: Should only be set to Truen when num_gpus is 1.
use_one_device_strategy: Should only be set to Truen when num_gpus is 1.
If True, then use OneDeviceStrategy; otherwise, do not use any
If True, then use OneDeviceStrategy; otherwise, do not use any
distribution strategy.
distribution strategy.
Returns:
Returns:
...
@@ -47,7 +47,7 @@ def get_distribution_strategy(
...
@@ -47,7 +47,7 @@ def get_distribution_strategy(
elif
num_gpus
==
1
:
elif
num_gpus
==
1
:
return
None
return
None
elif
use_one_device_strategy
:
elif
use_one_device_strategy
:
rase
ValueError
(
"When %d GPUs are specified, use_one_device_strategy"
ra
i
se
ValueError
(
"When %d GPUs are specified, use_one_device_strategy"
" flag cannot be set to True."
.
format
(
num_gpus
))
" flag cannot be set to True."
.
format
(
num_gpus
))
else
:
# num_gpus > 1 and not use_one_device_strategy
else
:
# num_gpus > 1 and not use_one_device_strategy
if
all_reduce_alg
:
if
all_reduce_alg
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment