Commit 7124ed12 authored by Jaehong Kim's avatar Jaehong Kim Committed by A. Unique TensorFlower
Browse files

Add a flag to switch layers module.

PiperOrigin-RevId: 288841926
parent 3bb3f185
......@@ -327,6 +327,12 @@ def define_keras_flags(dynamic_loss_scale=True):
help='Number of steps per graph-mode loop. Only training step happens '
'inside the loop. Callbacks will not be called inside. Will be capped at '
'steps per epoch.')
flags.DEFINE_boolean(
name='use_tf_keras_layers', default=False,
help='Whether to use tf.keras.layers instead of tf.python.keras.layers.'
'It only changes imagenet resnet model layers for now. This flag is '
'a temporal flag during transition to tf.keras.layers. Do not use this '
'flag for external usage. this will be removed shortly.')
def get_synth_data(height, width, num_channels, num_classes, dtype):
......
......@@ -230,6 +230,7 @@ def run(flags_obj):
flags_obj.log_steps)
with distribution_utils.get_strategy_scope(strategy):
resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
model = resnet_model.resnet50(
num_classes=imagenet_preprocessing.NUM_CLASSES,
batch_size=flags_obj.batch_size,
......
......@@ -170,6 +170,7 @@ def run(flags_obj):
model = trivial_model.trivial_model(
imagenet_preprocessing.NUM_CLASSES)
else:
resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
model = resnet_model.resnet50(
num_classes=imagenet_preprocessing.NUM_CLASSES)
......
......@@ -27,9 +27,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras import layers
from tensorflow.python.keras import layers as tf_python_keras_layers
from tensorflow.python.keras import models
from tensorflow.python.keras import regularizers
from official.vision.image_classification import imagenet_preprocessing
......@@ -38,6 +40,31 @@ L2_WEIGHT_DECAY = 1e-4
BATCH_NORM_DECAY = 0.9
BATCH_NORM_EPSILON = 1e-5
layers = tf_python_keras_layers
def change_keras_layer(use_tf_keras_layers=False):
"""Change layers to either tf.keras.layers or tf.python.keras.layers.
Layer version of tf.keras.layers is depends on tensorflow version, but
tf.python.keras.layers checks environment variable TF2_BEHAVIOR.
This function is a temporal function to use tf.keras.layers.
Currently, tf v2 batchnorm layer is slower than tf v1 batchnorm layer.
this function is useful for tracking benchmark result for each version.
This function will be removed when we use tf.keras.layers as default.
TODO(b/146939027): Remove this function when tf v2 batchnorm reaches training
speed parity with tf v1 batchnorm.
Args:
use_tf_keras_layers: whether to use tf.keras.layers.
"""
global layers
if use_tf_keras_layers:
layers = tf.keras.layers
else:
layers = tf_python_keras_layers
def _gen_l2_regularizer(use_l2_regularizer=True):
return regularizers.l2(L2_WEIGHT_DECAY) if use_l2_regularizer else None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment