Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
64f87cd2
Commit
64f87cd2
authored
Sep 06, 2019
by
A. Unique TensorFlower
Browse files
Merge pull request #7535 from houtoms:ctl_supports_amp
PiperOrigin-RevId: 267685527
parents
a629af4c
78047d54
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
83 additions
and
1 deletion
+83
-1
official/resnet/ctl/ctl_imagenet_benchmark.py
official/resnet/ctl/ctl_imagenet_benchmark.py
+65
-0
official/resnet/ctl/ctl_imagenet_main.py
official/resnet/ctl/ctl_imagenet_main.py
+18
-1
No files found.
official/resnet/ctl/ctl_imagenet_benchmark.py
View file @
64f87cd2
...
...
@@ -139,6 +139,21 @@ class Resnet50CtlAccuracy(CtlBenchmark):
FLAGS
.
datasets_num_private_threads
=
14
self
.
_run_and_report_benchmark
()
def
benchmark_8_gpu_amp
(
self
):
"""Test Keras model with eager, 8 GPUs with automatic mixed precision."""
self
.
_setup
()
FLAGS
.
num_gpus
=
8
FLAGS
.
data_dir
=
self
.
data_dir
FLAGS
.
batch_size
=
128
*
8
FLAGS
.
train_epochs
=
90
FLAGS
.
epochs_between_evals
=
10
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'benchmark_8_gpu_amp'
)
FLAGS
.
dtype
=
'fp16'
FLAGS
.
fp16_implementation
=
'graph_rewrite'
# Add some thread tunings to improve performance.
FLAGS
.
datasets_num_private_threads
=
14
self
.
_run_and_report_benchmark
()
def
_run_and_report_benchmark
(
self
):
start_time_sec
=
time
.
time
()
stats
=
ctl_imagenet_main
.
run
(
flags
.
FLAGS
)
...
...
@@ -206,6 +221,31 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS
.
batch_size
=
128
self
.
_run_and_report_benchmark
()
def
benchmark_1_gpu_amp
(
self
):
"""Test Keras model with 1 GPU with automatic mixed precision."""
self
.
_setup
()
FLAGS
.
num_gpus
=
1
FLAGS
.
distribution_strategy
=
'default'
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'benchmark_1_gpu_amp'
)
FLAGS
.
batch_size
=
128
FLAGS
.
dtype
=
'fp16'
FLAGS
.
fp16_implementation
=
'graph_rewrite'
self
.
_run_and_report_benchmark
()
def
benchmark_xla_1_gpu_amp
(
self
):
"""Test Keras model with XLA and 1 GPU with automatic mixed precision."""
self
.
_setup
()
FLAGS
.
num_gpus
=
1
FLAGS
.
distribution_strategy
=
'default'
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'benchmark_xla_1_gpu_amp'
)
FLAGS
.
batch_size
=
128
FLAGS
.
dtype
=
'fp16'
FLAGS
.
fp16_implementation
=
'graph_rewrite'
FLAGS
.
enable_xla
=
True
self
.
_run_and_report_benchmark
()
def
benchmark_1_gpu_eager
(
self
):
"""Test Keras model with 1 GPU in pure eager mode."""
self
.
_setup
()
...
...
@@ -228,6 +268,31 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS
.
batch_size
=
128
*
8
# 8 GPUs
self
.
_run_and_report_benchmark
()
def
benchmark_8_gpu_amp
(
self
):
"""Test Keras model with 8 GPUs with automatic mixed precision."""
self
.
_setup
()
FLAGS
.
num_gpus
=
8
FLAGS
.
distribution_strategy
=
'default'
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'benchmark_8_gpu_amp'
)
FLAGS
.
batch_size
=
128
*
8
# 8 GPUs
FLAGS
.
dtype
=
'fp16'
FLAGS
.
fp16_implementation
=
'graph_rewrite'
self
.
_run_and_report_benchmark
()
def
benchmark_xla_8_gpu_amp
(
self
):
"""Test Keras model with XLA and 8 GPUs with automatic mixed precision."""
self
.
_setup
()
FLAGS
.
num_gpus
=
8
FLAGS
.
distribution_strategy
=
'default'
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'benchmark_xla_8_gpu_amp'
)
FLAGS
.
batch_size
=
128
*
8
# 8 GPUs
FLAGS
.
dtype
=
'fp16'
FLAGS
.
fp16_implementation
=
'graph_rewrite'
FLAGS
.
enable_xla
=
True
self
.
_run_and_report_benchmark
()
def
fill_report_object
(
self
,
stats
):
super
(
Resnet50CtlBenchmarkBase
,
self
).
fill_report_object
(
stats
,
...
...
official/resnet/ctl/ctl_imagenet_main.py
View file @
64f87cd2
...
...
@@ -171,6 +171,14 @@ def run(flags_obj):
learning_rate
=
common
.
BASE_LEARNING_RATE
,
momentum
=
0.9
,
nesterov
=
True
)
if
flags_obj
.
fp16_implementation
==
"graph_rewrite"
:
if
not
flags_obj
.
use_tf_function
:
raise
ValueError
(
"--fp16_implementation=graph_rewrite requires "
"--use_tf_function to be true"
)
loss_scale
=
flags_core
.
get_loss_scale
(
flags_obj
,
default_for_fp16
=
128
)
optimizer
=
tf
.
train
.
experimental
.
enable_mixed_precision_graph_rewrite
(
optimizer
,
loss_scale
)
training_accuracy
=
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
'training_accuracy'
,
dtype
=
tf
.
float32
)
test_loss
=
tf
.
keras
.
metrics
.
Mean
(
'test_loss'
,
dtype
=
tf
.
float32
)
...
...
@@ -203,7 +211,17 @@ def run(flags_obj):
loss
+=
(
l2_loss
/
num_replicas
)
else
:
loss
+=
(
tf
.
reduce_sum
(
model
.
losses
)
/
num_replicas
)
# Scale the loss
if
flags_obj
.
dtype
==
"fp16"
:
loss
=
optimizer
.
get_scaled_loss
(
loss
)
grads
=
tape
.
gradient
(
loss
,
trainable_variables
)
# Unscale the grads
if
flags_obj
.
dtype
==
"fp16"
:
grads
=
optimizer
.
get_unscaled_gradients
(
grads
)
optimizer
.
apply_gradients
(
zip
(
grads
,
trainable_variables
))
training_accuracy
.
update_state
(
labels
,
logits
)
...
...
@@ -296,6 +314,5 @@ if __name__ == '__main__':
logging
.
set_verbosity
(
logging
.
INFO
)
common
.
define_keras_flags
()
ctl_common
.
define_ctl_flags
()
flags
.
adopt_module_key_flags
(
keras_common
)
flags
.
adopt_module_key_flags
(
ctl_common
)
absl_app
.
run
(
main
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment