Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
7124ed12
Commit
7124ed12
authored
Jan 08, 2020
by
Jaehong Kim
Committed by
A. Unique TensorFlower
Jan 08, 2020
Browse files
Add a flag to switch layers module.
PiperOrigin-RevId: 288841926
parent
3bb3f185
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
36 additions
and
1 deletion
+36
-1
official/vision/image_classification/common.py
official/vision/image_classification/common.py
+6
-0
official/vision/image_classification/resnet_ctl_imagenet_main.py
...l/vision/image_classification/resnet_ctl_imagenet_main.py
+1
-0
official/vision/image_classification/resnet_imagenet_main.py
official/vision/image_classification/resnet_imagenet_main.py
+1
-0
official/vision/image_classification/resnet_model.py
official/vision/image_classification/resnet_model.py
+28
-1
No files found.
official/vision/image_classification/common.py
View file @
7124ed12
...
...
@@ -327,6 +327,12 @@ def define_keras_flags(dynamic_loss_scale=True):
help
=
'Number of steps per graph-mode loop. Only training step happens '
'inside the loop. Callbacks will not be called inside. Will be capped at '
'steps per epoch.'
)
flags
.
DEFINE_boolean
(
name
=
'use_tf_keras_layers'
,
default
=
False
,
help
=
'Whether to use tf.keras.layers instead of tf.python.keras.layers.'
'It only changes imagenet resnet model layers for now. This flag is '
'a temporal flag during transition to tf.keras.layers. Do not use this '
'flag for external usage. this will be removed shortly.'
)
def
get_synth_data
(
height
,
width
,
num_channels
,
num_classes
,
dtype
):
...
...
official/vision/image_classification/resnet_ctl_imagenet_main.py
View file @
7124ed12
...
...
@@ -230,6 +230,7 @@ def run(flags_obj):
flags_obj
.
log_steps
)
with
distribution_utils
.
get_strategy_scope
(
strategy
):
resnet_model
.
change_keras_layer
(
flags_obj
.
use_tf_keras_layers
)
model
=
resnet_model
.
resnet50
(
num_classes
=
imagenet_preprocessing
.
NUM_CLASSES
,
batch_size
=
flags_obj
.
batch_size
,
...
...
official/vision/image_classification/resnet_imagenet_main.py
View file @
7124ed12
...
...
@@ -170,6 +170,7 @@ def run(flags_obj):
model
=
trivial_model
.
trivial_model
(
imagenet_preprocessing
.
NUM_CLASSES
)
else
:
resnet_model
.
change_keras_layer
(
flags_obj
.
use_tf_keras_layers
)
model
=
resnet_model
.
resnet50
(
num_classes
=
imagenet_preprocessing
.
NUM_CLASSES
)
...
...
official/vision/image_classification/resnet_model.py
View file @
7124ed12
...
...
@@ -27,9 +27,11 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
from
tensorflow.python.keras
import
backend
from
tensorflow.python.keras
import
initializers
from
tensorflow.python.keras
import
layers
from
tensorflow.python.keras
import
layers
as
tf_python_keras_layers
from
tensorflow.python.keras
import
models
from
tensorflow.python.keras
import
regularizers
from
official.vision.image_classification
import
imagenet_preprocessing
...
...
@@ -38,6 +40,31 @@ L2_WEIGHT_DECAY = 1e-4
BATCH_NORM_DECAY
=
0.9
BATCH_NORM_EPSILON
=
1e-5
layers
=
tf_python_keras_layers
def
change_keras_layer
(
use_tf_keras_layers
=
False
):
"""Change layers to either tf.keras.layers or tf.python.keras.layers.
Layer version of tf.keras.layers is depends on tensorflow version, but
tf.python.keras.layers checks environment variable TF2_BEHAVIOR.
This function is a temporal function to use tf.keras.layers.
Currently, tf v2 batchnorm layer is slower than tf v1 batchnorm layer.
this function is useful for tracking benchmark result for each version.
This function will be removed when we use tf.keras.layers as default.
TODO(b/146939027): Remove this function when tf v2 batchnorm reaches training
speed parity with tf v1 batchnorm.
Args:
use_tf_keras_layers: whether to use tf.keras.layers.
"""
global
layers
if
use_tf_keras_layers
:
layers
=
tf
.
keras
.
layers
else
:
layers
=
tf_python_keras_layers
def
_gen_l2_regularizer
(
use_l2_regularizer
=
True
):
return
regularizers
.
l2
(
L2_WEIGHT_DECAY
)
if
use_l2_regularizer
else
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment