Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
e4a046e7
Commit
e4a046e7
authored
Mar 06, 2019
by
Reed
Committed by
Toby Boyd
Mar 06, 2019
Browse files
Mixed precision support (#6309)
* Mixed precision support * Add TODOs
parent
8367cf6d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
10 deletions
+22
-10
official/resnet/keras/keras_imagenet_main.py
official/resnet/keras/keras_imagenet_main.py
+15
-7
official/resnet/keras/resnet_model.py
official/resnet/keras/resnet_model.py
+7
-3
No files found.
official/resnet/keras/keras_imagenet_main.py
View file @
e4a046e7
...
@@ -102,9 +102,9 @@ def run(flags_obj):
...
@@ -102,9 +102,9 @@ def run(flags_obj):
# TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready.
# TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready.
dtype
=
flags_core
.
get_tf_dtype
(
flags_obj
)
dtype
=
flags_core
.
get_tf_dtype
(
flags_obj
)
if
dtype
==
'f
p
16'
:
if
dtype
==
'f
loat
16'
:
raise
ValueError
(
'dtype fp16 is not supported in Keras. Use the default '
policy
=
tf
.
keras
.
mixed_precision
.
experimental
.
Policy
(
'infer_float32_vars'
)
'value(fp32).'
)
tf
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
policy
)
data_format
=
flags_obj
.
data_format
data_format
=
flags_obj
.
data_format
if
data_format
is
None
:
if
data_format
is
None
:
...
@@ -120,7 +120,7 @@ def run(flags_obj):
...
@@ -120,7 +120,7 @@ def run(flags_obj):
width
=
imagenet_main
.
DEFAULT_IMAGE_SIZE
,
width
=
imagenet_main
.
DEFAULT_IMAGE_SIZE
,
num_channels
=
imagenet_main
.
NUM_CHANNELS
,
num_channels
=
imagenet_main
.
NUM_CHANNELS
,
num_classes
=
imagenet_main
.
NUM_CLASSES
,
num_classes
=
imagenet_main
.
NUM_CLASSES
,
dtype
=
flags_core
.
get_tf_dtype
(
flags_obj
)
)
dtype
=
dtype
)
else
:
else
:
distribution_utils
.
undo_set_up_synthetic_data
()
distribution_utils
.
undo_set_up_synthetic_data
()
input_fn
=
imagenet_main
.
input_fn
input_fn
=
imagenet_main
.
input_fn
...
@@ -131,14 +131,16 @@ def run(flags_obj):
...
@@ -131,14 +131,16 @@ def run(flags_obj):
batch_size
=
flags_obj
.
batch_size
,
batch_size
=
flags_obj
.
batch_size
,
num_epochs
=
flags_obj
.
train_epochs
,
num_epochs
=
flags_obj
.
train_epochs
,
parse_record_fn
=
parse_record_keras
,
parse_record_fn
=
parse_record_keras
,
datasets_num_private_threads
=
flags_obj
.
datasets_num_private_threads
)
datasets_num_private_threads
=
flags_obj
.
datasets_num_private_threads
,
dtype
=
dtype
)
eval_input_dataset
=
input_fn
(
eval_input_dataset
=
input_fn
(
is_training
=
False
,
is_training
=
False
,
data_dir
=
flags_obj
.
data_dir
,
data_dir
=
flags_obj
.
data_dir
,
batch_size
=
flags_obj
.
batch_size
,
batch_size
=
flags_obj
.
batch_size
,
num_epochs
=
flags_obj
.
train_epochs
,
num_epochs
=
flags_obj
.
train_epochs
,
parse_record_fn
=
parse_record_keras
)
parse_record_fn
=
parse_record_keras
,
dtype
=
dtype
)
strategy
=
distribution_utils
.
get_distribution_strategy
(
strategy
=
distribution_utils
.
get_distribution_strategy
(
distribution_strategy
=
flags_obj
.
distribution_strategy
,
distribution_strategy
=
flags_obj
.
distribution_strategy
,
...
@@ -148,7 +150,13 @@ def run(flags_obj):
...
@@ -148,7 +150,13 @@ def run(flags_obj):
with
strategy_scope
:
with
strategy_scope
:
optimizer
=
keras_common
.
get_optimizer
()
optimizer
=
keras_common
.
get_optimizer
()
model
=
resnet_model
.
resnet50
(
num_classes
=
imagenet_main
.
NUM_CLASSES
)
if
dtype
==
'float16'
:
# TODO(reedwm): Remove manually wrapping optimizer once mixed precision
# can be enabled with a single line of code.
optimizer
=
tf
.
keras
.
mixed_precision
.
experimental
.
LossScaleOptimizer
(
optimizer
,
loss_scale
=
flags_core
.
get_loss_scale
(
flags_obj
))
model
=
resnet_model
.
resnet50
(
num_classes
=
imagenet_main
.
NUM_CLASSES
,
dtype
=
dtype
)
model
.
compile
(
loss
=
'sparse_categorical_crossentropy'
,
model
.
compile
(
loss
=
'sparse_categorical_crossentropy'
,
optimizer
=
optimizer
,
optimizer
=
optimizer
,
...
...
official/resnet/keras/resnet_model.py
View file @
e4a046e7
...
@@ -174,7 +174,7 @@ def conv_block(input_tensor,
...
@@ -174,7 +174,7 @@ def conv_block(input_tensor,
return
x
return
x
def
resnet50
(
num_classes
):
def
resnet50
(
num_classes
,
dtype
=
'float32'
):
# TODO(tfboyd): add training argument, just lik resnet56.
# TODO(tfboyd): add training argument, just lik resnet56.
"""Instantiates the ResNet50 architecture.
"""Instantiates the ResNet50 architecture.
...
@@ -185,7 +185,7 @@ def resnet50(num_classes):
...
@@ -185,7 +185,7 @@ def resnet50(num_classes):
A Keras model instance.
A Keras model instance.
"""
"""
input_shape
=
(
224
,
224
,
3
)
input_shape
=
(
224
,
224
,
3
)
img_input
=
layers
.
Input
(
shape
=
input_shape
)
img_input
=
layers
.
Input
(
shape
=
input_shape
,
dtype
=
dtype
)
if
backend
.
image_data_format
()
==
'channels_first'
:
if
backend
.
image_data_format
()
==
'channels_first'
:
x
=
layers
.
Lambda
(
lambda
x
:
backend
.
permute_dimensions
(
x
,
(
0
,
3
,
1
,
2
)),
x
=
layers
.
Lambda
(
lambda
x
:
backend
.
permute_dimensions
(
x
,
(
0
,
3
,
1
,
2
)),
...
@@ -232,10 +232,14 @@ def resnet50(num_classes):
...
@@ -232,10 +232,14 @@ def resnet50(num_classes):
x
=
layers
.
GlobalAveragePooling2D
(
name
=
'avg_pool'
)(
x
)
x
=
layers
.
GlobalAveragePooling2D
(
name
=
'avg_pool'
)(
x
)
x
=
layers
.
Dense
(
x
=
layers
.
Dense
(
num_classes
,
activation
=
'softmax'
,
num_classes
,
kernel_regularizer
=
regularizers
.
l2
(
L2_WEIGHT_DECAY
),
kernel_regularizer
=
regularizers
.
l2
(
L2_WEIGHT_DECAY
),
bias_regularizer
=
regularizers
.
l2
(
L2_WEIGHT_DECAY
),
bias_regularizer
=
regularizers
.
l2
(
L2_WEIGHT_DECAY
),
name
=
'fc1000'
)(
x
)
name
=
'fc1000'
)(
x
)
# TODO(reedwm): Remove manual casts once mixed precision can be enabled with a
# single line of code.
x
=
backend
.
cast
(
x
,
'float32'
)
x
=
layers
.
Activation
(
'softmax'
)(
x
)
# Create model.
# Create model.
return
models
.
Model
(
img_input
,
x
,
name
=
'resnet50'
)
return
models
.
Model
(
img_input
,
x
,
name
=
'resnet50'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment