Commit d2a038fa authored by Zongwei Zhou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 264647492
parent dd03f167
@@ -27,3 +27,6 @@ def define_ctl_flags():
   flags.DEFINE_boolean(name='use_tf_function', default=True,
                        help='Wrap the train and test step inside a '
                        'tf.function.')
+  flags.DEFINE_boolean(name='single_l2_loss_op', default=False,
+                       help='Calculate L2_loss on concatenated weights, '
+                       'instead of using Keras per-layer L2 loss.')
@@ -137,6 +137,10 @@ def run(flags_obj):
   Returns:
     Dictionary of training and eval stats.
   """
+  keras_utils.set_session_config(
+      enable_eager=flags_obj.enable_eager,
+      enable_xla=flags_obj.enable_xla)
+
   dtype = flags_core.get_tf_dtype(flags_obj)
   # TODO(anj-s): Set data_format without using Keras.
@@ -163,7 +167,8 @@ def run(flags_obj):
   with strategy_scope:
     model = resnet_model.resnet50(
         num_classes=imagenet_preprocessing.NUM_CLASSES,
-        dtype=dtype, batch_size=flags_obj.batch_size)
+        dtype=dtype, batch_size=flags_obj.batch_size,
+        use_l2_regularizer=not flags_obj.single_l2_loss_op)

     optimizer = tf.keras.optimizers.SGD(
         learning_rate=keras_common.BASE_LEARNING_RATE, momentum=0.9,
@@ -175,6 +180,8 @@ def run(flags_obj):
     test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
         'test_accuracy', dtype=tf.float32)

+    trainable_variables = model.trainable_variables
+
     def train_step(train_ds_inputs):
       """Training StepFn."""
       def step_fn(inputs):
@@ -185,13 +192,22 @@ def run(flags_obj):
         prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
             labels, logits)
-        loss1 = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
-        loss2 = (tf.reduce_sum(model.losses) /
-                 tf.distribute.get_strategy().num_replicas_in_sync)
-        loss = loss1 + loss2
+        loss = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
+        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
+
+        if flags_obj.single_l2_loss_op:
+          filtered_variables = [
+              tf.reshape(v, (-1,))
+              for v in trainable_variables
+              if 'bn' not in v.name
+          ]
+          l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
+              tf.concat(filtered_variables, axis=0))
+          loss += (l2_loss / num_replicas)
+        else:
+          loss += (tf.reduce_sum(model.losses) / num_replicas)

-        grads = tape.gradient(loss, model.trainable_variables)
-        optimizer.apply_gradients(zip(grads, model.trainable_variables))
+        grads = tape.gradient(loss, trainable_variables)
+        optimizer.apply_gradients(zip(grads, trainable_variables))

         training_accuracy.update_state(labels, logits)
         return loss
...
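Reviewer note on the fused path above: Keras `regularizers.l2(wd)` contributes wd * sum(v**2) per weight, while `tf.nn.l2_loss(t)` computes sum(t**2) / 2, which is why the new code multiplies by 2. The two paths also differ slightly in coverage: the fused op takes every trainable variable whose name lacks 'bn', whereas the Keras path only ever attached regularizers to Conv/Dense weights. A minimal standalone sketch of the numeric equivalence, with a made-up decay value standing in for resnet_model.L2_WEIGHT_DECAY:

    import tensorflow as tf

    wd = 1e-4  # stand-in for resnet_model.L2_WEIGHT_DECAY (assumed value)
    variables = [tf.ones((3, 3)), tf.ones((5,))]  # toy weights, no BN vars

    # Per-layer style: what regularizers.l2(wd) adds for each variable.
    per_layer = sum(wd * tf.reduce_sum(tf.square(v)) for v in variables)

    # Fused style: one l2_loss over the concatenated, flattened weights.
    flat = tf.concat([tf.reshape(v, (-1,)) for v in variables], axis=0)
    fused = wd * 2 * tf.nn.l2_loss(flat)  # l2_loss halves the sum, hence "* 2"

    tf.debugging.assert_near(per_layer, fused)  # equal up to float rounding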
@@ -28,7 +28,7 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.python.keras import backend
 from tensorflow.python.keras import initializers
 from tensorflow.python.keras import layers
 from tensorflow.python.keras import models
 from tensorflow.python.keras import regularizers
@@ -39,7 +39,16 @@ BATCH_NORM_DECAY = 0.9
 BATCH_NORM_EPSILON = 1e-5


-def identity_block(input_tensor, kernel_size, filters, stage, block):
+def _gen_l2_regularizer(use_l2_regularizer=True):
+  return regularizers.l2(L2_WEIGHT_DECAY) if use_l2_regularizer else None
+
+
+def identity_block(input_tensor,
+                   kernel_size,
+                   filters,
+                   stage,
+                   block,
+                   use_l2_regularizer=True):
   """The identity block is the block that has no conv layer at shortcut.

   Args:
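The new _gen_l2_regularizer helper relies on kernel_regularizer=None being a no-op for Keras layers, so passing use_l2_regularizer=False cleanly drops the per-layer loss. A small sketch of that behavior (the 1e-4 decay is an assumed stand-in for the module's L2_WEIGHT_DECAY constant):

    from tensorflow.python.keras import layers, regularizers

    L2_WEIGHT_DECAY = 1e-4  # assumed stand-in for the module-level constant

    def _gen_l2_regularizer(use_l2_regularizer=True):
      return regularizers.l2(L2_WEIGHT_DECAY) if use_l2_regularizer else None

    # kernel_regularizer=None is accepted and registers no per-layer loss.
    conv = layers.Conv2D(8, (1, 1), kernel_regularizer=_gen_l2_regularizer(False))
    conv.build((None, 32, 32, 3))
    assert not conv.losses  # nothing for model.losses to collect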
@@ -48,6 +57,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
     filters: list of integers, the filters of 3 conv layer at main path
     stage: integer, current stage label, used for generating layer names
     block: 'a','b'..., current block label, used for generating layer names
+    use_l2_regularizer: whether to use L2 regularizer on Conv layer.

   Returns:
     Output tensor for the block.
@@ -60,35 +70,51 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
   conv_name_base = 'res' + str(stage) + block + '_branch'
   bn_name_base = 'bn' + str(stage) + block + '_branch'

-  x = layers.Conv2D(filters1, (1, 1), use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name=conv_name_base + '2a')(input_tensor)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name=bn_name_base + '2a')(x)
+  x = layers.Conv2D(
+      filters1, (1, 1),
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '2a')(
+          input_tensor)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '2a')(
+          x)
   x = layers.Activation('relu')(x)

-  x = layers.Conv2D(filters2, kernel_size,
-                    padding='same', use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name=conv_name_base + '2b')(x)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name=bn_name_base + '2b')(x)
+  x = layers.Conv2D(
+      filters2,
+      kernel_size,
+      padding='same',
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '2b')(
+          x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '2b')(
+          x)
   x = layers.Activation('relu')(x)

-  x = layers.Conv2D(filters3, (1, 1), use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name=conv_name_base + '2c')(x)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name=bn_name_base + '2c')(x)
+  x = layers.Conv2D(
+      filters3, (1, 1),
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '2c')(
+          x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '2c')(
+          x)

   x = layers.add([x, input_tensor])
   x = layers.Activation('relu')(x)
@@ -100,7 +126,8 @@ def conv_block(input_tensor,
                filters,
                stage,
                block,
-               strides=(2, 2)):
+               strides=(2, 2),
+               use_l2_regularizer=True):
   """A block that has a conv layer at shortcut.

   Note that from stage 3,
@@ -114,6 +141,7 @@ def conv_block(input_tensor,
     stage: integer, current stage label, used for generating layer names
     block: 'a','b'..., current block label, used for generating layer names
     strides: Strides for the second conv layer in the block.
+    use_l2_regularizer: whether to use L2 regularizer on Conv layer.

   Returns:
     Output tensor for the block.
@@ -126,114 +154,231 @@ def conv_block(input_tensor,
   conv_name_base = 'res' + str(stage) + block + '_branch'
   bn_name_base = 'bn' + str(stage) + block + '_branch'

-  x = layers.Conv2D(filters1, (1, 1), use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name=conv_name_base + '2a')(input_tensor)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name=bn_name_base + '2a')(x)
+  x = layers.Conv2D(
+      filters1, (1, 1),
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '2a')(
+          input_tensor)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '2a')(
+          x)
   x = layers.Activation('relu')(x)

-  x = layers.Conv2D(filters2, kernel_size, strides=strides, padding='same',
-                    use_bias=False, kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name=conv_name_base + '2b')(x)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name=bn_name_base + '2b')(x)
+  x = layers.Conv2D(
+      filters2,
+      kernel_size,
+      strides=strides,
+      padding='same',
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '2b')(
+          x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '2b')(
+          x)
   x = layers.Activation('relu')(x)

-  x = layers.Conv2D(filters3, (1, 1), use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name=conv_name_base + '2c')(x)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name=bn_name_base + '2c')(x)
+  x = layers.Conv2D(
+      filters3, (1, 1),
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '2c')(
+          x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '2c')(
+          x)

-  shortcut = layers.Conv2D(filters3, (1, 1), strides=strides, use_bias=False,
-                           kernel_initializer='he_normal',
-                           kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                           name=conv_name_base + '1')(input_tensor)
-  shortcut = layers.BatchNormalization(axis=bn_axis,
-                                       momentum=BATCH_NORM_DECAY,
-                                       epsilon=BATCH_NORM_EPSILON,
-                                       name=bn_name_base + '1')(shortcut)
+  shortcut = layers.Conv2D(
+      filters3, (1, 1),
+      strides=strides,
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '1')(
+          input_tensor)
+  shortcut = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '1')(
+          shortcut)

   x = layers.add([x, shortcut])
   x = layers.Activation('relu')(x)
   return x

-def resnet50(num_classes, dtype='float32', batch_size=None):
+def resnet50(num_classes,
+             dtype='float32',
+             batch_size=None,
+             use_l2_regularizer=True):
   """Instantiates the ResNet50 architecture.

   Args:
     num_classes: `int` number of classes for image classification.
     dtype: dtype to use float32 or float16 are most common.
     batch_size: Size of the batches for each step.
+    use_l2_regularizer: whether to use L2 regularizer on Conv/Dense layer.

   Returns:
     A Keras model instance.
   """
   input_shape = (224, 224, 3)
-  img_input = layers.Input(shape=input_shape, dtype=dtype,
-                           batch_size=batch_size)
+  img_input = layers.Input(
+      shape=input_shape, dtype=dtype, batch_size=batch_size)
   if backend.image_data_format() == 'channels_first':
-    x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
-                      name='transpose')(img_input)
+    x = layers.Lambda(
+        lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
+        name='transpose')(
+            img_input)
     bn_axis = 1
   else:  # channels_last
     x = img_input
     bn_axis = 3
   x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
-  x = layers.Conv2D(64, (7, 7),
-                    strides=(2, 2),
-                    padding='valid', use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name='conv1')(x)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name='bn_conv1')(x)
+  x = layers.Conv2D(
+      64, (7, 7),
+      strides=(2, 2),
+      padding='valid',
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name='conv1')(
+          x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name='bn_conv1')(
+          x)
   x = layers.Activation('relu')(x)
   x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
-  x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
-  x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
-  x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
-
-  x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
-  x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
-  x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
-  x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
-
-  x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
-  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
-  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
-  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
-  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
-  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
-
-  x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
-  x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
-  x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
+  x = conv_block(
+      x,
+      3, [64, 64, 256],
+      stage=2,
+      block='a',
+      strides=(1, 1),
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [64, 64, 256],
+      stage=2,
+      block='b',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [64, 64, 256],
+      stage=2,
+      block='c',
+      use_l2_regularizer=use_l2_regularizer)
+  x = conv_block(
+      x,
+      3, [128, 128, 512],
+      stage=3,
+      block='a',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [128, 128, 512],
+      stage=3,
+      block='b',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [128, 128, 512],
+      stage=3,
+      block='c',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [128, 128, 512],
+      stage=3,
+      block='d',
+      use_l2_regularizer=use_l2_regularizer)
+  x = conv_block(
+      x,
+      3, [256, 256, 1024],
+      stage=4,
+      block='a',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [256, 256, 1024],
+      stage=4,
+      block='b',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [256, 256, 1024],
+      stage=4,
+      block='c',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [256, 256, 1024],
+      stage=4,
+      block='d',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [256, 256, 1024],
+      stage=4,
+      block='e',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [256, 256, 1024],
+      stage=4,
+      block='f',
+      use_l2_regularizer=use_l2_regularizer)
+  x = conv_block(
+      x,
+      3, [512, 512, 2048],
+      stage=5,
+      block='a',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [512, 512, 2048],
+      stage=5,
+      block='b',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [512, 512, 2048],
+      stage=5,
+      block='c',
+      use_l2_regularizer=use_l2_regularizer)
   rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3]
   x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x)
   x = layers.Dense(
       num_classes,
       kernel_initializer=initializers.RandomNormal(stddev=0.01),
-      kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-      bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-      name='fc1000')(x)
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      bias_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name='fc1000')(
+          x)

   # TODO(reedwm): Remove manual casts once mixed precision can be enabled with a
   # single line of code.
...
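End-to-end, a caller that opts into the fused loss builds the model without per-layer regularizers and must add the single L2 term in its own training step, exactly as the CTL hunks above do. A hypothetical smoke check (illustrative only, not part of this commit):

    # Hypothetical check of the new resnet50 parameter.
    model_a = resnet50(num_classes=1000, use_l2_regularizer=True)
    assert model_a.losses  # per-layer L2 terms flow through model.losses

    model_b = resnet50(num_classes=1000, use_l2_regularizer=False)
    assert not model_b.losses  # the training loop must add the fused L2 op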