"tests/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "b0a412402bcd482937316848789843d36c1b42bc"
Commit d2a038fa authored by Zongwei Zhou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 264647492
parent dd03f167
@@ -27,3 +27,6 @@ def define_ctl_flags():
   flags.DEFINE_boolean(name='use_tf_function', default=True,
                        help='Wrap the train and test step inside a '
                        'tf.function.')
+  flags.DEFINE_boolean(name='single_l2_loss_op', default=False,
+                       help='Calculate L2_loss on concatenated weights, '
+                       'instead of using Keras per-layer L2 loss.')
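Note: the new flag plugs into the standard absl flags machinery. A minimal sketch (not part of the commit) of how such a boolean flag is defined, parsed, and read:

    from absl import flags

    FLAGS = flags.FLAGS
    flags.DEFINE_boolean(name='single_l2_loss_op', default=False,
                         help='Calculate L2_loss on concatenated weights, '
                         'instead of using Keras per-layer L2 loss.')

    FLAGS(['prog', '--single_l2_loss_op=true'])  # parse an explicit argv
    print(FLAGS.single_l2_loss_op)  # True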
@@ -137,6 +137,10 @@ def run(flags_obj):
   Returns:
     Dictionary of training and eval stats.
   """
+  keras_utils.set_session_config(
+      enable_eager=flags_obj.enable_eager,
+      enable_xla=flags_obj.enable_xla)
+
   dtype = flags_core.get_tf_dtype(flags_obj)
   # TODO(anj-s): Set data_format without using Keras.
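For orientation: keras_utils.set_session_config centralizes runtime options such as eager execution and XLA. A rough illustration of what enabling XLA amounts to in TF 2.x (an analogy, not the helper's actual source):

    import tensorflow as tf

    # Illustration only: turn on XLA JIT compilation globally.
    tf.config.optimizer.set_jit(True)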
@@ -163,7 +167,8 @@ def run(flags_obj):
   with strategy_scope:
     model = resnet_model.resnet50(
         num_classes=imagenet_preprocessing.NUM_CLASSES,
-        dtype=dtype, batch_size=flags_obj.batch_size)
+        dtype=dtype, batch_size=flags_obj.batch_size,
+        use_l2_regularizer=not flags_obj.single_l2_loss_op)

     optimizer = tf.keras.optimizers.SGD(
         learning_rate=keras_common.BASE_LEARNING_RATE, momentum=0.9,
@@ -175,6 +180,8 @@ def run(flags_obj):
     test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
         'test_accuracy', dtype=tf.float32)

+  trainable_variables = model.trainable_variables
+
   def train_step(train_ds_inputs):
     """Training StepFn."""
     def step_fn(inputs):
@@ -185,13 +192,22 @@ def run(flags_obj):
         prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
             labels, logits)
-        loss1 = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
-        loss2 = (tf.reduce_sum(model.losses) /
-                 tf.distribute.get_strategy().num_replicas_in_sync)
-        loss = loss1 + loss2
-      grads = tape.gradient(loss, model.trainable_variables)
-      optimizer.apply_gradients(zip(grads, model.trainable_variables))
+        loss = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
+        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
+        if flags_obj.single_l2_loss_op:
+          filtered_variables = [
+              tf.reshape(v, (-1,))
+              for v in trainable_variables
+              if 'bn' not in v.name
+          ]
+          l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
+              tf.concat(filtered_variables, axis=0))
+          loss += (l2_loss / num_replicas)
+        else:
+          loss += (tf.reduce_sum(model.losses) / num_replicas)
+      grads = tape.gradient(loss, trainable_variables)
+      optimizer.apply_gradients(zip(grads, trainable_variables))
       training_accuracy.update_state(labels, logits)
       return loss
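The single_l2_loss_op branch is intended to produce the same penalty as the Keras per-layer losses it replaces: regularizers.l2(wd) contributes wd * sum(w**2) per layer, while tf.nn.l2_loss(t) computes sum(t**2) / 2, so wd * 2 * tf.nn.l2_loss(concat) recovers wd times the squared norm of all concatenated weights. The division by num_replicas compensates for per-replica losses being summed across the replicas of the distribution strategy. A standalone check of the equivalence (assumes TF 2.x; names here are illustrative):

    import tensorflow as tf

    wd = 1e-4  # weight decay; the model file calls this L2_WEIGHT_DECAY
    weights = [tf.random.normal((3, 3)), tf.random.normal((5,))]

    # Per-layer style: wd * sum(w**2) for each weight, then summed.
    per_layer = tf.add_n([wd * tf.reduce_sum(tf.square(w)) for w in weights])

    # Single-op style: one l2_loss over the concatenated, flattened weights.
    flat = tf.concat([tf.reshape(w, (-1,)) for w in weights], axis=0)
    single_op = wd * 2 * tf.nn.l2_loss(flat)

    print(float(per_layer), float(single_op))  # the two values agree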
@@ -39,7 +39,16 @@ BATCH_NORM_DECAY = 0.9
 BATCH_NORM_EPSILON = 1e-5


-def identity_block(input_tensor, kernel_size, filters, stage, block):
+def _gen_l2_regularizer(use_l2_regularizer=True):
+  return regularizers.l2(L2_WEIGHT_DECAY) if use_l2_regularizer else None
+
+
+def identity_block(input_tensor,
+                   kernel_size,
+                   filters,
+                   stage,
+                   block,
+                   use_l2_regularizer=True):
   """The identity block is the block that has no conv layer at shortcut.

   Args:
@@ -48,6 +57,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
     filters: list of integers, the filters of 3 conv layer at main path
     stage: integer, current stage label, used for generating layer names
     block: 'a','b'..., current block label, used for generating layer names
+    use_l2_regularizer: whether to use L2 regularizer on Conv layer.

   Returns:
     Output tensor for the block.
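The new helper's contract is simple: Keras treats kernel_regularizer=None as "no regularization", so returning None cleanly disables the per-layer L2 term. A small sketch of its behavior (the L2_WEIGHT_DECAY value is assumed for illustration; the diff only shows the constant's name):

    from tensorflow.keras import regularizers

    L2_WEIGHT_DECAY = 1e-4  # assumed value for illustration

    def _gen_l2_regularizer(use_l2_regularizer=True):
      return regularizers.l2(L2_WEIGHT_DECAY) if use_l2_regularizer else None

    print(_gen_l2_regularizer(True))   # an L2 regularizer object
    print(_gen_l2_regularizer(False))  # None, so the layer adds no loss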
@@ -60,35 +70,51 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
   conv_name_base = 'res' + str(stage) + block + '_branch'
   bn_name_base = 'bn' + str(stage) + block + '_branch'

-  x = layers.Conv2D(filters1, (1, 1), use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name=conv_name_base + '2a')(input_tensor)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name=bn_name_base + '2a')(x)
+  x = layers.Conv2D(
+      filters1, (1, 1),
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '2a')(
+          input_tensor)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '2a')(
+          x)
   x = layers.Activation('relu')(x)

-  x = layers.Conv2D(filters2, kernel_size,
-                    padding='same', use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name=conv_name_base + '2b')(x)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name=bn_name_base + '2b')(x)
+  x = layers.Conv2D(
+      filters2,
+      kernel_size,
+      padding='same',
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '2b')(
+          x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '2b')(
+          x)
   x = layers.Activation('relu')(x)

-  x = layers.Conv2D(filters3, (1, 1), use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name=conv_name_base + '2c')(x)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name=bn_name_base + '2c')(x)
+  x = layers.Conv2D(
+      filters3, (1, 1),
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '2c')(
+          x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '2c')(
+          x)

   x = layers.add([x, input_tensor])
   x = layers.Activation('relu')(x)
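The 'bn' substring filter in the training loop relies on the naming convention established here (bn_name_base = 'bn' + ...): batch-norm gamma and beta variables inherit the layer name and are therefore excluded from weight decay. A toy demonstration with hypothetical layer names:

    import tensorflow as tf

    inputs = tf.keras.Input((8,))
    x = tf.keras.layers.Dense(4, name='fc')(inputs)
    x = tf.keras.layers.BatchNormalization(name='bn_demo')(x)
    model = tf.keras.Model(inputs, x)

    decayed = [v.name for v in model.trainable_variables if 'bn' not in v.name]
    skipped = [v.name for v in model.trainable_variables if 'bn' in v.name]
    print(decayed)  # ['fc/kernel:0', 'fc/bias:0']
    print(skipped)  # ['bn_demo/gamma:0', 'bn_demo/beta:0']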
@@ -100,7 +126,8 @@ def conv_block(input_tensor,
                filters,
                stage,
                block,
-               strides=(2, 2)):
+               strides=(2, 2),
+               use_l2_regularizer=True):
   """A block that has a conv layer at shortcut.

   Note that from stage 3,
@@ -114,6 +141,7 @@ def conv_block(input_tensor,
     stage: integer, current stage label, used for generating layer names
     block: 'a','b'..., current block label, used for generating layer names
     strides: Strides for the second conv layer in the block.
+    use_l2_regularizer: whether to use L2 regularizer on Conv layer.

   Returns:
     Output tensor for the block.
@@ -126,114 +154,231 @@ def conv_block(input_tensor,
   conv_name_base = 'res' + str(stage) + block + '_branch'
   bn_name_base = 'bn' + str(stage) + block + '_branch'

-  x = layers.Conv2D(filters1, (1, 1), use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name=conv_name_base + '2a')(input_tensor)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name=bn_name_base + '2a')(x)
+  x = layers.Conv2D(
+      filters1, (1, 1),
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '2a')(
+          input_tensor)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '2a')(
+          x)
   x = layers.Activation('relu')(x)

-  x = layers.Conv2D(filters2, kernel_size, strides=strides, padding='same',
-                    use_bias=False, kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name=conv_name_base + '2b')(x)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name=bn_name_base + '2b')(x)
+  x = layers.Conv2D(
+      filters2,
+      kernel_size,
+      strides=strides,
+      padding='same',
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '2b')(
+          x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '2b')(
+          x)
   x = layers.Activation('relu')(x)

-  x = layers.Conv2D(filters3, (1, 1), use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name=conv_name_base + '2c')(x)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name=bn_name_base + '2c')(x)
+  x = layers.Conv2D(
+      filters3, (1, 1),
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '2c')(
+          x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '2c')(
+          x)

-  shortcut = layers.Conv2D(filters3, (1, 1), strides=strides, use_bias=False,
-                           kernel_initializer='he_normal',
-                           kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                           name=conv_name_base + '1')(input_tensor)
-  shortcut = layers.BatchNormalization(axis=bn_axis,
-                                       momentum=BATCH_NORM_DECAY,
-                                       epsilon=BATCH_NORM_EPSILON,
-                                       name=bn_name_base + '1')(shortcut)
+  shortcut = layers.Conv2D(
+      filters3, (1, 1),
+      strides=strides,
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '1')(
+          input_tensor)
+  shortcut = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '1')(
+          shortcut)

   x = layers.add([x, shortcut])
   x = layers.Activation('relu')(x)
   return x
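A hypothetical smoke test for the widened signatures (assumes this module's imports and channels_last data format): both block functions now thread use_l2_regularizer through every Conv2D they build.

    import tensorflow as tf

    inputs = tf.keras.Input((56, 56, 256))
    out = conv_block(inputs, 3, [64, 64, 256], stage=2, block='a',
                     strides=(1, 1), use_l2_regularizer=False)
    out = identity_block(out, 3, [64, 64, 256], stage=2, block='b',
                         use_l2_regularizer=False)
    print(out.shape)  # (None, 56, 56, 256)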
-def resnet50(num_classes, dtype='float32', batch_size=None):
+def resnet50(num_classes,
+             dtype='float32',
+             batch_size=None,
+             use_l2_regularizer=True):
   """Instantiates the ResNet50 architecture.

   Args:
     num_classes: `int` number of classes for image classification.
     dtype: dtype to use float32 or float16 are most common.
     batch_size: Size of the batches for each step.
+    use_l2_regularizer: whether to use L2 regularizer on Conv/Dense layer.

   Returns:
     A Keras model instance.
   """
   input_shape = (224, 224, 3)
-  img_input = layers.Input(shape=input_shape, dtype=dtype,
-                           batch_size=batch_size)
+  img_input = layers.Input(
+      shape=input_shape, dtype=dtype, batch_size=batch_size)
   if backend.image_data_format() == 'channels_first':
-    x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
-                      name='transpose')(img_input)
+    x = layers.Lambda(
+        lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
+        name='transpose')(
+            img_input)
     bn_axis = 1
   else:  # channels_last
     x = img_input
     bn_axis = 3

   x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
-  x = layers.Conv2D(64, (7, 7),
-                    strides=(2, 2),
-                    padding='valid', use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name='conv1')(x)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name='bn_conv1')(x)
+  x = layers.Conv2D(
+      64, (7, 7),
+      strides=(2, 2),
+      padding='valid',
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name='conv1')(
+          x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name='bn_conv1')(
+          x)
   x = layers.Activation('relu')(x)
   x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
-  x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
-  x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
-  x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
-  x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
-  x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
-  x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
-  x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
-  x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
-  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
-  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
-  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
-  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
-  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
-  x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
-  x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
-  x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
+  x = conv_block(
+      x,
+      3, [64, 64, 256],
+      stage=2,
+      block='a',
+      strides=(1, 1),
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [64, 64, 256],
+      stage=2,
+      block='b',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [64, 64, 256],
+      stage=2,
+      block='c',
+      use_l2_regularizer=use_l2_regularizer)
+  x = conv_block(
+      x,
+      3, [128, 128, 512],
+      stage=3,
+      block='a',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [128, 128, 512],
+      stage=3,
+      block='b',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [128, 128, 512],
+      stage=3,
+      block='c',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [128, 128, 512],
+      stage=3,
+      block='d',
+      use_l2_regularizer=use_l2_regularizer)
+  x = conv_block(
+      x,
+      3, [256, 256, 1024],
+      stage=4,
+      block='a',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [256, 256, 1024],
+      stage=4,
+      block='b',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [256, 256, 1024],
+      stage=4,
+      block='c',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [256, 256, 1024],
+      stage=4,
+      block='d',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [256, 256, 1024],
+      stage=4,
+      block='e',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [256, 256, 1024],
+      stage=4,
+      block='f',
+      use_l2_regularizer=use_l2_regularizer)
+  x = conv_block(
+      x,
+      3, [512, 512, 2048],
+      stage=5,
+      block='a',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [512, 512, 2048],
+      stage=5,
+      block='b',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [512, 512, 2048],
+      stage=5,
+      block='c',
+      use_l2_regularizer=use_l2_regularizer)
   rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3]
   x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x)
   x = layers.Dense(
       num_classes,
       kernel_initializer=initializers.RandomNormal(stddev=0.01),
-      kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-      bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-      name='fc1000')(x)
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      bias_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name='fc1000')(
+          x)

   # TODO(reedwm): Remove manual casts once mixed precision can be enabled with a
   # single line of code.
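Finally, a hedged usage sketch: with use_l2_regularizer=False the model collects no per-layer regularization losses, which is exactly what the new single_l2_loss_op training path expects (the import path is assumed from the repo layout at this commit):

    # from official.resnet.keras import resnet_model  # path assumed
    model = resnet_model.resnet50(num_classes=1000, use_l2_regularizer=False)
    print(len(model.losses))  # 0: tf.reduce_sum(model.losses) would add nothing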