Commit 7a257585 authored by Hongkun Yu, committed by A. Unique TensorFlower

Readability: Avoid global variables in resnet construction.

PiperOrigin-RevId: 302932162
parent ad09cf49
@@ -35,15 +35,11 @@ from tensorflow.python.keras import models
 from tensorflow.python.keras import regularizers
 from official.vision.image_classification.resnet import imagenet_preprocessing
 
-L2_WEIGHT_DECAY = 1e-4
-BATCH_NORM_DECAY = 0.9
-BATCH_NORM_EPSILON = 1e-5
-
 layers = tf.keras.layers
 
 
-def _gen_l2_regularizer(use_l2_regularizer=True):
-  return regularizers.l2(L2_WEIGHT_DECAY) if use_l2_regularizer else None
+def _gen_l2_regularizer(use_l2_regularizer=True, l2_weight_decay=1e-4):
+  return regularizers.l2(l2_weight_decay) if use_l2_regularizer else None
 
 
 def identity_block(input_tensor,
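
The helper now receives the weight-decay coefficient as a parameter instead of reading the removed L2_WEIGHT_DECAY global. A minimal usage sketch, assuming a caller that wires it into a Conv layer (the layer name and the 5e-5 value below are illustrative, not from this commit):

    import tensorflow as tf
    from tensorflow.python.keras import regularizers

    def _gen_l2_regularizer(use_l2_regularizer=True, l2_weight_decay=1e-4):
      # Returns a Keras L2 regularizer when enabled, else None (no penalty).
      return regularizers.l2(l2_weight_decay) if use_l2_regularizer else None

    conv = tf.keras.layers.Conv2D(
        64, (3, 3),
        kernel_regularizer=_gen_l2_regularizer(True, l2_weight_decay=5e-5),
        name='conv_example')
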
@@ -51,7 +47,9 @@ def identity_block(input_tensor,
                    filters,
                    stage,
                    block,
-                   use_l2_regularizer=True):
+                   use_l2_regularizer=True,
+                   batch_norm_decay=0.9,
+                   batch_norm_epsilon=1e-5):
   """The identity block is the block that has no conv layer at shortcut.
 
   Args:
@@ -61,6 +59,8 @@ def identity_block(input_tensor,
     stage: integer, current stage label, used for generating layer names
     block: 'a','b'..., current block label, used for generating layer names
     use_l2_regularizer: whether to use L2 regularizer on Conv layer.
+    batch_norm_decay: Momentum of batch norm layers.
+    batch_norm_epsilon: Epsilon of batch norm layers.
 
   Returns:
     Output tensor for the block.
......@@ -82,8 +82,8 @@ def identity_block(input_tensor,
input_tensor)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
momentum=batch_norm_decay,
epsilon=batch_norm_epsilon,
name=bn_name_base + '2a')(
x)
x = layers.Activation('relu')(x)
@@ -99,8 +99,8 @@ def identity_block(input_tensor,
           x)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2b')(
           x)
   x = layers.Activation('relu')(x)
@@ -114,8 +114,8 @@ def identity_block(input_tensor,
           x)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2c')(
           x)
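
Every layers.BatchNormalization call in both block functions now takes these two values as arguments: momentum is the decay rate of the moving mean and variance accumulated during training, and epsilon is added to the variance to avoid division by zero. A hedged sketch of overriding them at one call site (the 0.997 and 1e-3 values are illustrative, not defaults from this file):

    import tensorflow as tf

    inputs = tf.keras.Input(shape=(56, 56, 256))
    # axis=3 normalizes over the channels_last dimension.
    outputs = tf.keras.layers.BatchNormalization(
        axis=3, momentum=0.997, epsilon=1e-3)(inputs)
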
@@ -130,7 +130,9 @@ def conv_block(input_tensor,
                stage,
                block,
                strides=(2, 2),
-               use_l2_regularizer=True):
+               use_l2_regularizer=True,
+               batch_norm_decay=0.9,
+               batch_norm_epsilon=1e-5):
   """A block that has a conv layer at shortcut.
 
   Note that from stage 3,
@@ -145,6 +147,8 @@ def conv_block(input_tensor,
     block: 'a','b'..., current block label, used for generating layer names
     strides: Strides for the second conv layer in the block.
     use_l2_regularizer: whether to use L2 regularizer on Conv layer.
+    batch_norm_decay: Momentum of batch norm layers.
+    batch_norm_epsilon: Epsilon of batch norm layers.
 
   Returns:
     Output tensor for the block.
@@ -166,8 +170,8 @@ def conv_block(input_tensor,
           input_tensor)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2a')(
           x)
   x = layers.Activation('relu')(x)
......@@ -184,8 +188,8 @@ def conv_block(input_tensor,
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
momentum=batch_norm_decay,
epsilon=batch_norm_epsilon,
name=bn_name_base + '2b')(
x)
x = layers.Activation('relu')(x)
@@ -199,8 +203,8 @@ def conv_block(input_tensor,
           x)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2c')(
           x)
@@ -214,8 +218,8 @@ def conv_block(input_tensor,
           input_tensor)
   shortcut = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '1')(
           shortcut)
@@ -227,7 +231,9 @@ def conv_block(input_tensor,
 def resnet50(num_classes,
              batch_size=None,
              use_l2_regularizer=True,
-             rescale_inputs=False):
+             rescale_inputs=False,
+             batch_norm_decay=0.9,
+             batch_norm_epsilon=1e-5):
   """Instantiates the ResNet50 architecture.
 
   Args:
@@ -235,6 +241,8 @@ def resnet50(num_classes,
     batch_size: Size of the batches for each step.
     use_l2_regularizer: whether to use L2 regularizer on Conv/Dense layer.
     rescale_inputs: whether to rescale inputs from 0 to 1.
+    batch_norm_decay: Momentum of batch norm layers.
+    batch_norm_epsilon: Epsilon of batch norm layers.
 
   Returns:
     A Keras model instance.
@@ -260,6 +268,10 @@ def resnet50(num_classes,
   else:  # channels_last
     bn_axis = 3
 
+  block_config = dict(
+      use_l2_regularizer=use_l2_regularizer,
+      batch_norm_decay=batch_norm_decay,
+      batch_norm_epsilon=batch_norm_epsilon)
   x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
   x = layers.Conv2D(
       64, (7, 7),
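
The block_config dict is the replacement for the deleted globals: shared settings are collected once and expanded into every block call with **. A self-contained sketch of the idiom, with an illustrative make_block standing in for conv_block/identity_block:

    def make_block(x, filters, use_l2_regularizer=True,
                   batch_norm_decay=0.9, batch_norm_epsilon=1e-5):
      return (x, filters, use_l2_regularizer,
              batch_norm_decay, batch_norm_epsilon)

    block_config = dict(
        use_l2_regularizer=False,
        batch_norm_decay=0.99,
        batch_norm_epsilon=1e-3)
    # Identical to passing the three keyword arguments explicitly.
    out = make_block('input', 256, **block_config)
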
@@ -272,113 +284,33 @@ def resnet50(num_classes,
           x)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name='bn_conv1')(
           x)
   x = layers.Activation('relu')(x)
   x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
 
   x = conv_block(
-      x,
-      3, [64, 64, 256],
-      stage=2,
-      block='a',
-      strides=(1, 1),
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [64, 64, 256],
-      stage=2,
-      block='b',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [64, 64, 256],
-      stage=2,
-      block='c',
-      use_l2_regularizer=use_l2_regularizer)
-
-  x = conv_block(
-      x,
-      3, [128, 128, 512],
-      stage=3,
-      block='a',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [128, 128, 512],
-      stage=3,
-      block='b',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [128, 128, 512],
-      stage=3,
-      block='c',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [128, 128, 512],
-      stage=3,
-      block='d',
-      use_l2_regularizer=use_l2_regularizer)
-
-  x = conv_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='a',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='b',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='c',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='d',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='e',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='f',
-      use_l2_regularizer=use_l2_regularizer)
-
-  x = conv_block(
-      x,
-      3, [512, 512, 2048],
-      stage=5,
-      block='a',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [512, 512, 2048],
-      stage=5,
-      block='b',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [512, 512, 2048],
-      stage=5,
-      block='c',
-      use_l2_regularizer=use_l2_regularizer)
+      x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1), **block_config)
+  x = identity_block(x, 3, [64, 64, 256], stage=2, block='b', **block_config)
+  x = identity_block(x, 3, [64, 64, 256], stage=2, block='c', **block_config)
+
+  x = conv_block(x, 3, [128, 128, 512], stage=3, block='a', **block_config)
+  x = identity_block(x, 3, [128, 128, 512], stage=3, block='b', **block_config)
+  x = identity_block(x, 3, [128, 128, 512], stage=3, block='c', **block_config)
+  x = identity_block(x, 3, [128, 128, 512], stage=3, block='d', **block_config)
+
+  x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f', **block_config)
+
+  x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a', **block_config)
+  x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b', **block_config)
+  x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c', **block_config)
 
   x = layers.GlobalAveragePooling2D()(x)
   x = layers.Dense(
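
With the widened signature, callers tune normalization and regularization per model instance rather than by editing module globals. A hedged construction sketch (the argument values simply restate the defaults):

    model = resnet50(
        num_classes=1000,
        batch_size=None,
        use_l2_regularizer=True,
        batch_norm_decay=0.9,
        batch_norm_epsilon=1e-5)
    model.summary()
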
@@ -158,9 +158,9 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
         loss = tf.reduce_sum(prediction_loss) * (1.0 /
                                                  self.flags_obj.batch_size)
         num_replicas = self.strategy.num_replicas_in_sync
-
+        l2_weight_decay = 1e-4
         if self.flags_obj.single_l2_loss_op:
-          l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.add_n([
+          l2_loss = l2_weight_decay * 2 * tf.add_n([
               tf.nn.l2_loss(v)
               for v in self.model.trainable_variables
               if 'bn' not in v.name
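
The runner now inlines the weight-decay constant rather than reaching into resnet_model. The factor of 2 compensates for tf.nn.l2_loss, which computes sum(v ** 2) / 2, and variables whose names contain 'bn' (the batch-norm scales and offsets) are conventionally exempt from weight decay. A standalone sketch of the same computation:

    import tensorflow as tf

    def l2_regularization_loss(model, l2_weight_decay=1e-4):
      # Sum of squared trainable weights, scaled; skips batch-norm variables.
      return l2_weight_decay * 2 * tf.add_n([
          tf.nn.l2_loss(v)
          for v in model.trainable_variables
          if 'bn' not in v.name
      ])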