"tests/vscode:/vscode.git/clone" did not exist on "d0c30cfd37c2b4e3c9e9cec6887f13c63a3e684e"
Commit cf5f1321 authored by Jaehong Kim's avatar Jaehong Kim Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 300447325
parent b39958cf
...@@ -1305,7 +1305,6 @@ class Resnet50KerasPruningAccuracy(KerasPruningAccuracyBase): ...@@ -1305,7 +1305,6 @@ class Resnet50KerasPruningAccuracy(KerasPruningAccuracyBase):
'model': 'resnet50_v1.5', 'model': 'resnet50_v1.5',
'optimizer': 'mobilenet_default', 'optimizer': 'mobilenet_default',
'initial_learning_rate_per_sample': 0.0000039, 'initial_learning_rate_per_sample': 0.0000039,
'use_tf_keras_layers': True,
'pretrained_filepath': tf.train.latest_checkpoint( 'pretrained_filepath': tf.train.latest_checkpoint(
os.path.join(root_data_dir, 'resnet50')), os.path.join(root_data_dir, 'resnet50')),
'pruning_begin_step': 0, 'pruning_begin_step': 0,
...@@ -1369,7 +1368,6 @@ class Resnet50KerasPruningBenchmarkReal(KerasPruningBenchmarkRealBase): ...@@ -1369,7 +1368,6 @@ class Resnet50KerasPruningBenchmarkReal(KerasPruningBenchmarkRealBase):
default_flags = { default_flags = {
'model': 'resnet50_v1.5', 'model': 'resnet50_v1.5',
'optimizer': 'mobilenet_default', 'optimizer': 'mobilenet_default',
'use_tf_keras_layers': True,
} }
super(Resnet50KerasPruningBenchmarkReal, self).__init__( super(Resnet50KerasPruningBenchmarkReal, self).__init__(
default_flags=default_flags, **kwargs) default_flags=default_flags, **kwargs)
......
...@@ -275,12 +275,6 @@ def define_keras_flags( ...@@ -275,12 +275,6 @@ def define_keras_flags(
help='Whether to build a tf.while_loop inside the training loop on the ' help='Whether to build a tf.while_loop inside the training loop on the '
'host. Setting it to True is critical to have peak performance on ' 'host. Setting it to True is critical to have peak performance on '
'TPU.') 'TPU.')
flags.DEFINE_boolean(
name='use_tf_keras_layers', default=False,
help='Whether to use tf.keras.layers instead of tf.python.keras.layers.'
'It only changes imagenet resnet model layers for now. This flag is '
'a temporal flag during transition to tf.keras.layers. Do not use this '
'flag for external usage. this will be removed shortly.')
if model: if model:
flags.DEFINE_string('model', 'resnet50_v1.5', flags.DEFINE_string('model', 'resnet50_v1.5',
......
...@@ -31,7 +31,6 @@ import tensorflow as tf ...@@ -31,7 +31,6 @@ import tensorflow as tf
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers from tensorflow.python.keras import initializers
from tensorflow.python.keras import layers as tf_python_keras_layers
from tensorflow.python.keras import models from tensorflow.python.keras import models
from tensorflow.python.keras import regularizers from tensorflow.python.keras import regularizers
from official.vision.image_classification import imagenet_preprocessing from official.vision.image_classification import imagenet_preprocessing
...@@ -40,30 +39,7 @@ L2_WEIGHT_DECAY = 1e-4 ...@@ -40,30 +39,7 @@ L2_WEIGHT_DECAY = 1e-4
BATCH_NORM_DECAY = 0.9 BATCH_NORM_DECAY = 0.9
BATCH_NORM_EPSILON = 1e-5 BATCH_NORM_EPSILON = 1e-5
layers = tf_python_keras_layers layers = tf.keras.layers
def change_keras_layer(use_tf_keras_layers=False):
"""Change layers to either tf.keras.layers or tf.python.keras.layers.
Layer version of tf.keras.layers is depends on tensorflow version, but
tf.python.keras.layers checks environment variable TF2_BEHAVIOR.
This function is a temporal function to use tf.keras.layers.
Currently, tf v2 batchnorm layer is slower than tf v1 batchnorm layer.
this function is useful for tracking benchmark result for each version.
This function will be removed when we use tf.keras.layers as default.
TODO(b/146939027): Remove this function when tf v2 batchnorm reaches training
speed parity with tf v1 batchnorm.
Args:
use_tf_keras_layers: whether to use tf.keras.layers.
"""
global layers
if use_tf_keras_layers:
layers = tf.keras.layers
else:
layers = tf_python_keras_layers
def _gen_l2_regularizer(use_l2_regularizer=True): def _gen_l2_regularizer(use_l2_regularizer=True):
......
...@@ -183,7 +183,6 @@ def run(flags_obj): ...@@ -183,7 +183,6 @@ def run(flags_obj):
model = trivial_model.trivial_model( model = trivial_model.trivial_model(
imagenet_preprocessing.NUM_CLASSES) imagenet_preprocessing.NUM_CLASSES)
elif flags_obj.model == 'resnet50_v1.5': elif flags_obj.model == 'resnet50_v1.5':
resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
model = resnet_model.resnet50( model = resnet_model.resnet50(
num_classes=imagenet_preprocessing.NUM_CLASSES) num_classes=imagenet_preprocessing.NUM_CLASSES)
elif flags_obj.model == 'mobilenet': elif flags_obj.model == 'mobilenet':
......
...@@ -50,7 +50,6 @@ class KerasImagenetTest(tf.test.TestCase): ...@@ -50,7 +50,6 @@ class KerasImagenetTest(tf.test.TestCase):
"-model", "resnet50_v1.5", "-model", "resnet50_v1.5",
"-optimizer", "resnet50_default", "-optimizer", "resnet50_default",
"-pruning_method", "polynomial_decay", "-pruning_method", "polynomial_decay",
"-use_tf_keras_layers", "true",
], ],
"mobilenet": [ "mobilenet": [
"-model", "mobilenet", "-model", "mobilenet",
......
...@@ -70,7 +70,6 @@ class ResnetRunnable(standard_runnable.StandardTrainable, ...@@ -70,7 +70,6 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
else: else:
self.input_fn = imagenet_preprocessing.input_fn self.input_fn = imagenet_preprocessing.input_fn
resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
self.model = resnet_model.resnet50( self.model = resnet_model.resnet50(
num_classes=imagenet_preprocessing.NUM_CLASSES, num_classes=imagenet_preprocessing.NUM_CLASSES,
batch_size=flags_obj.batch_size, batch_size=flags_obj.batch_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment