Commit 7a257585 authored by Hongkun Yu, committed by A. Unique TensorFlower

Readability: Avoid global variables in resnet construction.

PiperOrigin-RevId: 302932162
parent ad09cf49
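The commit replaces the module-level constants L2_WEIGHT_DECAY, BATCH_NORM_DECAY, and BATCH_NORM_EPSILON with defaulted keyword arguments threaded through the block builders. A minimal sketch of the pattern, using a hypothetical helper rather than code from this commit:

```python
import tensorflow as tf

# Before: behavior is coupled to a module global that callers cannot override.
BATCH_NORM_DECAY = 0.9

def make_bn_global(axis):
  return tf.keras.layers.BatchNormalization(axis=axis, momentum=BATCH_NORM_DECAY)

# After: the value is an explicit parameter; the default preserves old behavior.
def make_bn(axis, batch_norm_decay=0.9):
  return tf.keras.layers.BatchNormalization(axis=axis, momentum=batch_norm_decay)
```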
@@ -35,15 +35,11 @@ from tensorflow.python.keras import models
 from tensorflow.python.keras import regularizers
 
 from official.vision.image_classification.resnet import imagenet_preprocessing
 
-L2_WEIGHT_DECAY = 1e-4
-BATCH_NORM_DECAY = 0.9
-BATCH_NORM_EPSILON = 1e-5
-
 layers = tf.keras.layers
 
 
-def _gen_l2_regularizer(use_l2_regularizer=True):
-  return regularizers.l2(L2_WEIGHT_DECAY) if use_l2_regularizer else None
+def _gen_l2_regularizer(use_l2_regularizer=True, l2_weight_decay=1e-4):
+  return regularizers.l2(l2_weight_decay) if use_l2_regularizer else None
 
 
 def identity_block(input_tensor,
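With the new signature, the weight decay is tunable per call instead of requiring an edit to a global. An illustrative use (the 5e-5 value is made up, and _gen_l2_regularizer is module-private):

```python
from official.vision.image_classification.resnet import resnet_model

# Per-model weight decay, no global edits needed.
reg = resnet_model._gen_l2_regularizer(use_l2_regularizer=True, l2_weight_decay=5e-5)
none = resnet_model._gen_l2_regularizer(use_l2_regularizer=False)  # returns None
```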
@@ -51,7 +47,9 @@ def identity_block(input_tensor,
                    filters,
                    stage,
                    block,
-                   use_l2_regularizer=True):
+                   use_l2_regularizer=True,
+                   batch_norm_decay=0.9,
+                   batch_norm_epsilon=1e-5):
   """The identity block is the block that has no conv layer at shortcut.
 
   Args:
@@ -61,6 +59,8 @@ def identity_block(input_tensor,
     stage: integer, current stage label, used for generating layer names
     block: 'a','b'..., current block label, used for generating layer names
     use_l2_regularizer: whether to use L2 regularizer on Conv layer.
+    batch_norm_decay: Momentum of batch norm layers.
+    batch_norm_epsilon: Epsilon of batch norm layers.
 
   Returns:
     Output tensor for the block.
@@ -82,8 +82,8 @@ def identity_block(input_tensor,
           input_tensor)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2a')(
           x)
   x = layers.Activation('relu')(x)
@@ -99,8 +99,8 @@ def identity_block(input_tensor,
           x)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2b')(
           x)
   x = layers.Activation('relu')(x)
@@ -114,8 +114,8 @@ def identity_block(input_tensor,
           x)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2c')(
           x)
@@ -130,7 +130,9 @@ def conv_block(input_tensor,
                stage,
                block,
                strides=(2, 2),
-               use_l2_regularizer=True):
+               use_l2_regularizer=True,
+               batch_norm_decay=0.9,
+               batch_norm_epsilon=1e-5):
   """A block that has a conv layer at shortcut.
 
   Note that from stage 3,
@@ -145,6 +147,8 @@ def conv_block(input_tensor,
     block: 'a','b'..., current block label, used for generating layer names
     strides: Strides for the second conv layer in the block.
     use_l2_regularizer: whether to use L2 regularizer on Conv layer.
+    batch_norm_decay: Momentum of batch norm layers.
+    batch_norm_epsilon: Epsilon of batch norm layers.
 
   Returns:
     Output tensor for the block.
@@ -166,8 +170,8 @@ def conv_block(input_tensor,
           input_tensor)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2a')(
           x)
   x = layers.Activation('relu')(x)
@@ -184,8 +188,8 @@ def conv_block(input_tensor,
           x)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2b')(
           x)
   x = layers.Activation('relu')(x)
@@ -199,8 +203,8 @@ def conv_block(input_tensor,
           x)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2c')(
           x)
@@ -214,8 +218,8 @@ def conv_block(input_tensor,
           input_tensor)
   shortcut = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '1')(
           shortcut)
@@ -227,7 +231,9 @@ def conv_block(input_tensor,
 def resnet50(num_classes,
              batch_size=None,
              use_l2_regularizer=True,
-             rescale_inputs=False):
+             rescale_inputs=False,
+             batch_norm_decay=0.9,
+             batch_norm_epsilon=1e-5):
   """Instantiates the ResNet50 architecture.
 
   Args:
@@ -235,6 +241,8 @@ def resnet50(num_classes,
     batch_size: Size of the batches for each step.
     use_l2_regularizer: whether to use L2 regularizer on Conv/Dense layer.
     rescale_inputs: whether to rescale inputs from 0 to 1.
+    batch_norm_decay: Momentum of batch norm layers.
+    batch_norm_epsilon: Epsilon of batch norm layers.
 
   Returns:
     A Keras model instance.
@@ -260,6 +268,10 @@ def resnet50(num_classes,
   else:  # channels_last
     bn_axis = 3
 
+  block_config = dict(
+      use_l2_regularizer=use_l2_regularizer,
+      batch_norm_decay=batch_norm_decay,
+      batch_norm_epsilon=batch_norm_epsilon)
   x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
   x = layers.Conv2D(
       64, (7, 7),
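block_config gathers the shared keyword arguments once, and **block_config in the hunk below splats them into all sixteen block calls, so adding a new hyperparameter means touching one dict rather than every call site. The same Python mechanism in isolation, with a toy stand-in function (not code from this commit):

```python
def block(x, filters, use_l2_regularizer=True,
          batch_norm_decay=0.9, batch_norm_epsilon=1e-5):
  # Toy stand-in for conv_block/identity_block: just echoes its inputs.
  return (x, filters, use_l2_regularizer, batch_norm_decay, batch_norm_epsilon)

block_config = dict(
    use_l2_regularizer=True, batch_norm_decay=0.9, batch_norm_epsilon=1e-5)
# ** expands the dict into keyword arguments, equivalent to writing each one out.
print(block('input', 64, **block_config))
```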
@@ -272,113 +284,33 @@ def resnet50(num_classes,
           x)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name='bn_conv1')(
           x)
   x = layers.Activation('relu')(x)
   x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
 
-  x = conv_block(
-      x,
-      3, [64, 64, 256],
-      stage=2,
-      block='a',
-      strides=(1, 1),
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [64, 64, 256],
-      stage=2,
-      block='b',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [64, 64, 256],
-      stage=2,
-      block='c',
-      use_l2_regularizer=use_l2_regularizer)
-
-  x = conv_block(
-      x,
-      3, [128, 128, 512],
-      stage=3,
-      block='a',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [128, 128, 512],
-      stage=3,
-      block='b',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [128, 128, 512],
-      stage=3,
-      block='c',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [128, 128, 512],
-      stage=3,
-      block='d',
-      use_l2_regularizer=use_l2_regularizer)
-
-  x = conv_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='a',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='b',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='c',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='d',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='e',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='f',
-      use_l2_regularizer=use_l2_regularizer)
-
-  x = conv_block(
-      x,
-      3, [512, 512, 2048],
-      stage=5,
-      block='a',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [512, 512, 2048],
-      stage=5,
-      block='b',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [512, 512, 2048],
-      stage=5,
-      block='c',
-      use_l2_regularizer=use_l2_regularizer)
+  x = conv_block(
+      x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1), **block_config)
+  x = identity_block(x, 3, [64, 64, 256], stage=2, block='b', **block_config)
+  x = identity_block(x, 3, [64, 64, 256], stage=2, block='c', **block_config)
+
+  x = conv_block(x, 3, [128, 128, 512], stage=3, block='a', **block_config)
+  x = identity_block(x, 3, [128, 128, 512], stage=3, block='b', **block_config)
+  x = identity_block(x, 3, [128, 128, 512], stage=3, block='c', **block_config)
+  x = identity_block(x, 3, [128, 128, 512], stage=3, block='d', **block_config)
+
+  x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f', **block_config)
+
+  x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a', **block_config)
+  x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b', **block_config)
+  x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c', **block_config)
 
   x = layers.GlobalAveragePooling2D()(x)
   x = layers.Dense(
......
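With the new parameters, a caller can build the network with non-default batch-norm settings directly; a sketch with illustrative override values:

```python
from official.vision.image_classification.resnet import resnet_model

model = resnet_model.resnet50(
    num_classes=1000,
    batch_norm_decay=0.99,    # illustrative override; the default remains 0.9
    batch_norm_epsilon=1e-3)  # illustrative override; the default remains 1e-5
```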
@@ -158,9 +158,9 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
       loss = tf.reduce_sum(prediction_loss) * (1.0 /
                                                self.flags_obj.batch_size)
       num_replicas = self.strategy.num_replicas_in_sync
+      l2_weight_decay = 1e-4
       if self.flags_obj.single_l2_loss_op:
-        l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.add_n([
+        l2_loss = l2_weight_decay * 2 * tf.add_n([
             tf.nn.l2_loss(v)
             for v in self.model.trainable_variables
             if 'bn' not in v.name
......
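The factor of 2 here is not arbitrary: tf.nn.l2_loss(v) returns sum(v**2) / 2, while regularizers.l2(wd) penalizes wd * sum(v**2), so 2 * l2_weight_decay * tf.nn.l2_loss(v) matches the per-layer regularizers that this single op replaces. A quick check of that equivalence:

```python
import tensorflow as tf

v = tf.constant([1.0, 2.0, 3.0])              # sum(v**2) == 14
wd = 1e-4
single_op = wd * 2 * tf.nn.l2_loss(v)         # wd * sum(v**2)
per_layer = tf.keras.regularizers.l2(wd)(v)   # also wd * sum(v**2)
print(float(single_op), float(per_layer))     # both print 0.0014
```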