"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "9d86ca677c48b4e006156b1edce631e6c4bd5c12"
Commit d2a038fa authored by Zongwei Zhou, committed by A. Unique TensorFlower

Internal change: add a single_l2_loss_op option to the CTL ResNet50 trainer. When the flag is set, the per-layer Keras L2 regularizers are disabled (use_l2_regularizer=False) and weight decay is instead computed once per step as a single tf.nn.l2_loss over the concatenated non-batch-norm trainable variables.

PiperOrigin-RevId: 264647492
parent dd03f167
@@ -27,3 +27,6 @@ def define_ctl_flags():
   flags.DEFINE_boolean(name='use_tf_function', default=True,
                        help='Wrap the train and test step inside a '
                        'tf.function.')
+  flags.DEFINE_boolean(name='single_l2_loss_op', default=False,
+                       help='Calculate L2_loss on concatenated weights, '
+                       'instead of using Keras per-layer L2 loss.')
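Note that the new flag defaults to False, so the existing behavior (per-layer Keras L2 loss) is unchanged unless it is set. For reference, a minimal standalone sketch of how such an absl flag is defined and read; the script scaffolding here is illustrative and not part of the commit:

```python
# Illustrative only: defines the same boolean flag with absl and prints it.
from absl import app, flags

flags.DEFINE_boolean(
    name='single_l2_loss_op',
    default=False,
    help='Calculate L2_loss on concatenated weights, '
    'instead of using Keras per-layer L2 loss.')
FLAGS = flags.FLAGS


def main(_):
  # The trainer passes the negation to the model builder:
  #   use_l2_regularizer=not FLAGS.single_l2_loss_op
  print('single_l2_loss_op =', FLAGS.single_l2_loss_op)


if __name__ == '__main__':
  app.run(main)
```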
@@ -137,6 +137,10 @@ def run(flags_obj):
   Returns:
     Dictionary of training and eval stats.
   """
+  keras_utils.set_session_config(
+      enable_eager=flags_obj.enable_eager,
+      enable_xla=flags_obj.enable_xla)
+
   dtype = flags_core.get_tf_dtype(flags_obj)
   # TODO(anj-s): Set data_format without using Keras.
@@ -163,7 +167,8 @@ def run(flags_obj):
   with strategy_scope:
     model = resnet_model.resnet50(
         num_classes=imagenet_preprocessing.NUM_CLASSES,
-        dtype=dtype, batch_size=flags_obj.batch_size)
+        dtype=dtype, batch_size=flags_obj.batch_size,
+        use_l2_regularizer=not flags_obj.single_l2_loss_op)

     optimizer = tf.keras.optimizers.SGD(
         learning_rate=keras_common.BASE_LEARNING_RATE, momentum=0.9,
@@ -175,6 +180,8 @@ def run(flags_obj):
     test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
         'test_accuracy', dtype=tf.float32)

+    trainable_variables = model.trainable_variables
+
     def train_step(train_ds_inputs):
       """Training StepFn."""
       def step_fn(inputs):
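The hunk above hoists model.trainable_variables into a local that is read once, presumably because the property rebuilds its list on each access and the captured list gives the step function a stable object to close over. A minimal sketch of the pattern, assuming TF 2.x; the toy model and shapes are my own, not from the commit:

```python
# Sketch of the hoisting pattern, not the trainer itself (TF 2.x assumed).
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
optimizer = tf.keras.optimizers.SGD(0.1)
trainable_variables = model.trainable_variables  # read once, reused each step


@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x) - y))
  # Closing over the captured list avoids re-querying the model every step.
  grads = tape.gradient(loss, trainable_variables)
  optimizer.apply_gradients(zip(grads, trainable_variables))
  return loss


print(train_step(tf.random.normal([2, 8]), tf.random.normal([2, 4])))
```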
@@ -185,13 +192,22 @@ def run(flags_obj):
         prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
             labels, logits)
-        loss1 = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
-        loss2 = (tf.reduce_sum(model.losses) /
-                 tf.distribute.get_strategy().num_replicas_in_sync)
-        loss = loss1 + loss2
+        loss = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
+        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
+
+        if flags_obj.single_l2_loss_op:
+          filtered_variables = [
+              tf.reshape(v, (-1,))
+              for v in trainable_variables
+              if 'bn' not in v.name
+          ]
+          l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
+              tf.concat(filtered_variables, axis=0))
+          loss += (l2_loss / num_replicas)
+        else:
+          loss += (tf.reduce_sum(model.losses) / num_replicas)

-        grads = tape.gradient(loss, model.trainable_variables)
-        optimizer.apply_gradients(zip(grads, model.trainable_variables))
+        grads = tape.gradient(loss, trainable_variables)
+        optimizer.apply_gradients(zip(grads, trainable_variables))
         training_accuracy.update_state(labels, logits)
         return loss
...
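The two branches above are intended to produce the same weight decay: Keras regularizers.l2(l) adds l * sum(w**2) per kernel to model.losses, while tf.nn.l2_loss(t) is sum(t**2) / 2, so L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(concat) yields the same total over the concatenated kernels. The 'bn' not in v.name filter skips batch-norm parameters, which the per-layer path never regularized. A quick numerical check of that identity; the random tensors are stand-ins for the model's kernels:

```python
# Verifies the single-op L2 penalty equals the sum of per-variable penalties.
import tensorflow as tf

L2_WEIGHT_DECAY = 1e-4  # stand-in for resnet_model.L2_WEIGHT_DECAY
variables = [tf.random.normal([3, 3, 16, 16]), tf.random.normal([16, 32])]

flat = [tf.reshape(v, (-1,)) for v in variables]
single_op = L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(tf.concat(flat, axis=0))
per_layer = tf.add_n(
    [L2_WEIGHT_DECAY * tf.reduce_sum(tf.square(v)) for v in variables])

# Equal up to float32 summation-order rounding.
tf.debugging.assert_near(single_op, per_layer, rtol=1e-4)
```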
@@ -39,7 +39,16 @@ BATCH_NORM_DECAY = 0.9
 BATCH_NORM_EPSILON = 1e-5


-def identity_block(input_tensor, kernel_size, filters, stage, block):
+def _gen_l2_regularizer(use_l2_regularizer=True):
+  return regularizers.l2(L2_WEIGHT_DECAY) if use_l2_regularizer else None
+
+
+def identity_block(input_tensor,
+                   kernel_size,
+                   filters,
+                   stage,
+                   block,
+                   use_l2_regularizer=True):
   """The identity block is the block that has no conv layer at shortcut.

   Args:
@@ -48,6 +57,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
     filters: list of integers, the filters of 3 conv layer at main path
     stage: integer, current stage label, used for generating layer names
     block: 'a','b'..., current block label, used for generating layer names
+    use_l2_regularizer: whether to use L2 regularizer on Conv layer.

   Returns:
     Output tensor for the block.
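The new _gen_l2_regularizer helper centralizes the on/off switch: Keras layers accept kernel_regularizer=None as "no regularizer", so passing use_l2_regularizer=False simply drops the per-layer loss term. A small self-contained sketch; the constant value is a stand-in for the module's:

```python
# Sketch: returning None disables the per-layer L2 term entirely.
from tensorflow.keras import layers, regularizers

L2_WEIGHT_DECAY = 1e-4  # stand-in for the module constant


def _gen_l2_regularizer(use_l2_regularizer=True):
  return regularizers.l2(L2_WEIGHT_DECAY) if use_l2_regularizer else None


conv = layers.Conv2D(8, (1, 1), kernel_regularizer=_gen_l2_regularizer(False))
assert conv.kernel_regularizer is None  # layer adds nothing to model.losses
```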
@@ -60,35 +70,51 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
   conv_name_base = 'res' + str(stage) + block + '_branch'
   bn_name_base = 'bn' + str(stage) + block + '_branch'

-  x = layers.Conv2D(filters1, (1, 1), use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name=conv_name_base + '2a')(input_tensor)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name=bn_name_base + '2a')(x)
+  x = layers.Conv2D(
+      filters1, (1, 1),
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '2a')(
+          input_tensor)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '2a')(
+          x)
   x = layers.Activation('relu')(x)

-  x = layers.Conv2D(filters2, kernel_size,
-                    padding='same', use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name=conv_name_base + '2b')(x)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name=bn_name_base + '2b')(x)
+  x = layers.Conv2D(
+      filters2,
+      kernel_size,
+      padding='same',
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '2b')(
+          x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '2b')(
+          x)
   x = layers.Activation('relu')(x)

-  x = layers.Conv2D(filters3, (1, 1), use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name=conv_name_base + '2c')(x)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name=bn_name_base + '2c')(x)
+  x = layers.Conv2D(
+      filters3, (1, 1),
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '2c')(
+          x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '2c')(
+          x)

   x = layers.add([x, input_tensor])
   x = layers.Activation('relu')(x)
@@ -100,7 +126,8 @@ def conv_block(input_tensor,
                filters,
                stage,
                block,
-               strides=(2, 2)):
+               strides=(2, 2),
+               use_l2_regularizer=True):
   """A block that has a conv layer at shortcut.

   Note that from stage 3,
@@ -114,6 +141,7 @@ def conv_block(input_tensor,
     stage: integer, current stage label, used for generating layer names
     block: 'a','b'..., current block label, used for generating layer names
     strides: Strides for the second conv layer in the block.
+    use_l2_regularizer: whether to use L2 regularizer on Conv layer.

   Returns:
     Output tensor for the block.
@@ -126,114 +154,231 @@ def conv_block(input_tensor,
   conv_name_base = 'res' + str(stage) + block + '_branch'
   bn_name_base = 'bn' + str(stage) + block + '_branch'

-  x = layers.Conv2D(filters1, (1, 1), use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name=conv_name_base + '2a')(input_tensor)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name=bn_name_base + '2a')(x)
+  x = layers.Conv2D(
+      filters1, (1, 1),
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '2a')(
+          input_tensor)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '2a')(
+          x)
   x = layers.Activation('relu')(x)

-  x = layers.Conv2D(filters2, kernel_size, strides=strides, padding='same',
-                    use_bias=False, kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name=conv_name_base + '2b')(x)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name=bn_name_base + '2b')(x)
+  x = layers.Conv2D(
+      filters2,
+      kernel_size,
+      strides=strides,
+      padding='same',
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '2b')(
+          x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '2b')(
+          x)
   x = layers.Activation('relu')(x)

-  x = layers.Conv2D(filters3, (1, 1), use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name=conv_name_base + '2c')(x)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name=bn_name_base + '2c')(x)
+  x = layers.Conv2D(
+      filters3, (1, 1),
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '2c')(
+          x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '2c')(
+          x)

-  shortcut = layers.Conv2D(filters3, (1, 1), strides=strides, use_bias=False,
-                           kernel_initializer='he_normal',
-                           kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                           name=conv_name_base + '1')(input_tensor)
-  shortcut = layers.BatchNormalization(axis=bn_axis,
-                                       momentum=BATCH_NORM_DECAY,
-                                       epsilon=BATCH_NORM_EPSILON,
-                                       name=bn_name_base + '1')(shortcut)
+  shortcut = layers.Conv2D(
+      filters3, (1, 1),
+      strides=strides,
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '1')(
+          input_tensor)
+  shortcut = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '1')(
+          shortcut)

   x = layers.add([x, shortcut])
   x = layers.Activation('relu')(x)
   return x

-def resnet50(num_classes, dtype='float32', batch_size=None):
+def resnet50(num_classes,
+             dtype='float32',
+             batch_size=None,
+             use_l2_regularizer=True):
   """Instantiates the ResNet50 architecture.

   Args:
     num_classes: `int` number of classes for image classification.
     dtype: dtype to use float32 or float16 are most common.
     batch_size: Size of the batches for each step.
+    use_l2_regularizer: whether to use L2 regularizer on Conv/Dense layer.

   Returns:
     A Keras model instance.
   """
   input_shape = (224, 224, 3)
-  img_input = layers.Input(shape=input_shape, dtype=dtype,
-                           batch_size=batch_size)
+  img_input = layers.Input(
+      shape=input_shape, dtype=dtype, batch_size=batch_size)

   if backend.image_data_format() == 'channels_first':
-    x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
-                      name='transpose')(img_input)
+    x = layers.Lambda(
+        lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
+        name='transpose')(
+            img_input)
     bn_axis = 1
   else:  # channels_last
     x = img_input
     bn_axis = 3

   x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
-  x = layers.Conv2D(64, (7, 7),
-                    strides=(2, 2),
-                    padding='valid', use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name='conv1')(x)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name='bn_conv1')(x)
+  x = layers.Conv2D(
+      64, (7, 7),
+      strides=(2, 2),
+      padding='valid',
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name='conv1')(
+          x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name='bn_conv1')(
+          x)
   x = layers.Activation('relu')(x)
   x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
-  x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
-  x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
-  x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
-
-  x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
-  x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
-  x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
-  x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
-
-  x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
-  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
-  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
-  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
-  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
-  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
-
-  x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
-  x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
-  x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
+  x = conv_block(
+      x,
+      3, [64, 64, 256],
+      stage=2,
+      block='a',
+      strides=(1, 1),
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [64, 64, 256],
+      stage=2,
+      block='b',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [64, 64, 256],
+      stage=2,
+      block='c',
+      use_l2_regularizer=use_l2_regularizer)
+  x = conv_block(
+      x,
+      3, [128, 128, 512],
+      stage=3,
+      block='a',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [128, 128, 512],
+      stage=3,
+      block='b',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [128, 128, 512],
+      stage=3,
+      block='c',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [128, 128, 512],
+      stage=3,
+      block='d',
+      use_l2_regularizer=use_l2_regularizer)
+  x = conv_block(
+      x,
+      3, [256, 256, 1024],
+      stage=4,
+      block='a',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [256, 256, 1024],
+      stage=4,
+      block='b',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [256, 256, 1024],
+      stage=4,
+      block='c',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [256, 256, 1024],
+      stage=4,
+      block='d',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [256, 256, 1024],
+      stage=4,
+      block='e',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [256, 256, 1024],
+      stage=4,
+      block='f',
+      use_l2_regularizer=use_l2_regularizer)
+  x = conv_block(
+      x,
+      3, [512, 512, 2048],
+      stage=5,
+      block='a',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [512, 512, 2048],
+      stage=5,
+      block='b',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [512, 512, 2048],
+      stage=5,
+      block='c',
+      use_l2_regularizer=use_l2_regularizer)
   rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3]
   x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x)
   x = layers.Dense(
       num_classes,
       kernel_initializer=initializers.RandomNormal(stddev=0.01),
-      kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-      bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-      name='fc1000')(x)
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      bias_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name='fc1000')(
+          x)

   # TODO(reedwm): Remove manual casts once mixed precision can be enabled with a
   # single line of code.
...