Commit 4ca9e10a authored by Scott Zhu, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 377092815
parent dfce8c78
...@@ -26,10 +26,6 @@ from __future__ import print_function ...@@ -26,10 +26,6 @@ from __future__ import print_function
import functools import functools
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras import layers
from tensorflow.python.keras import regularizers
BATCH_NORM_DECAY = 0.997 BATCH_NORM_DECAY = 0.997
BATCH_NORM_EPSILON = 1e-5 BATCH_NORM_EPSILON = 1e-5
...@@ -57,48 +53,48 @@ def identity_building_block(input_tensor, ...@@ -57,48 +53,48 @@ def identity_building_block(input_tensor,
Output tensor for the block. Output tensor for the block.
""" """
filters1, filters2 = filters filters1, filters2 = filters
if backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = 3 bn_axis = 3
else: else:
bn_axis = 1 bn_axis = 1
conv_name_base = 'res' + str(stage) + block + '_branch' conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch' bn_name_base = 'bn' + str(stage) + block + '_branch'
x = layers.Conv2D( x = tf.keras.layers.Conv2D(
filters1, filters1,
kernel_size, kernel_size,
padding='same', padding='same',
use_bias=False, use_bias=False,
kernel_initializer='he_normal', kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')( name=conv_name_base + '2a')(
input_tensor) input_tensor)
x = layers.BatchNormalization( x = tf.keras.layers.BatchNormalization(
axis=bn_axis, axis=bn_axis,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')( name=bn_name_base + '2a')(
x, training=training) x, training=training)
x = layers.Activation('relu')(x) x = tf.keras.layers.Activation('relu')(x)
x = layers.Conv2D( x = tf.keras.layers.Conv2D(
filters2, filters2,
kernel_size, kernel_size,
padding='same', padding='same',
use_bias=False, use_bias=False,
kernel_initializer='he_normal', kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')( name=conv_name_base + '2b')(
x) x)
x = layers.BatchNormalization( x = tf.keras.layers.BatchNormalization(
axis=bn_axis, axis=bn_axis,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')( name=bn_name_base + '2b')(
x, training=training) x, training=training)
x = layers.add([x, input_tensor]) x = tf.keras.layers.add([x, input_tensor])
x = layers.Activation('relu')(x) x = tf.keras.layers.Activation('relu')(x)
return x return x
...@@ -136,57 +132,57 @@ def conv_building_block(input_tensor, ...@@ -136,57 +132,57 @@ def conv_building_block(input_tensor,
conv_name_base = 'res' + str(stage) + block + '_branch' conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch' bn_name_base = 'bn' + str(stage) + block + '_branch'
x = layers.Conv2D( x = tf.keras.layers.Conv2D(
filters1, filters1,
kernel_size, kernel_size,
strides=strides, strides=strides,
padding='same', padding='same',
use_bias=False, use_bias=False,
kernel_initializer='he_normal', kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')( name=conv_name_base + '2a')(
input_tensor) input_tensor)
x = layers.BatchNormalization( x = tf.keras.layers.BatchNormalization(
axis=bn_axis, axis=bn_axis,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')( name=bn_name_base + '2a')(
x, training=training) x, training=training)
x = layers.Activation('relu')(x) x = tf.keras.layers.Activation('relu')(x)
x = layers.Conv2D( x = tf.keras.layers.Conv2D(
filters2, filters2,
kernel_size, kernel_size,
padding='same', padding='same',
use_bias=False, use_bias=False,
kernel_initializer='he_normal', kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')( name=conv_name_base + '2b')(
x) x)
x = layers.BatchNormalization( x = tf.keras.layers.BatchNormalization(
axis=bn_axis, axis=bn_axis,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')( name=bn_name_base + '2b')(
x, training=training) x, training=training)
shortcut = layers.Conv2D( shortcut = tf.keras.layers.Conv2D(
filters2, (1, 1), filters2, (1, 1),
strides=strides, strides=strides,
use_bias=False, use_bias=False,
kernel_initializer='he_normal', kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
name=conv_name_base + '1')( name=conv_name_base + '1')(
input_tensor) input_tensor)
shortcut = layers.BatchNormalization( shortcut = tf.keras.layers.BatchNormalization(
axis=bn_axis, axis=bn_axis,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '1')( name=bn_name_base + '1')(
shortcut, training=training) shortcut, training=training)
x = layers.add([x, shortcut]) x = tf.keras.layers.add([x, shortcut])
x = layers.Activation('relu')(x) x = tf.keras.layers.Activation('relu')(x)
return x return x
...@@ -252,11 +248,11 @@ def resnet(num_blocks, classes=10, training=None): ...@@ -252,11 +248,11 @@ def resnet(num_blocks, classes=10, training=None):
""" """
input_shape = (32, 32, 3) input_shape = (32, 32, 3)
img_input = layers.Input(shape=input_shape) img_input = tf.keras.Input(shape=input_shape)
if backend.image_data_format() == 'channels_first': if tf.keras.backend.image_data_format() == 'channels_first':
x = layers.Lambda( x = tf.keras.layers.Lambda(
lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)), lambda x: tf.keras.backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')( name='transpose')(
img_input) img_input)
bn_axis = 1 bn_axis = 1
...@@ -264,23 +260,23 @@ def resnet(num_blocks, classes=10, training=None): ...@@ -264,23 +260,23 @@ def resnet(num_blocks, classes=10, training=None):
x = img_input x = img_input
bn_axis = 3 bn_axis = 3
x = layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x) x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x)
x = layers.Conv2D( x = tf.keras.layers.Conv2D(
16, (3, 3), 16, (3, 3),
strides=(1, 1), strides=(1, 1),
padding='valid', padding='valid',
use_bias=False, use_bias=False,
kernel_initializer='he_normal', kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
name='conv1')( name='conv1')(
x) x)
x = layers.BatchNormalization( x = tf.keras.layers.BatchNormalization(
axis=bn_axis, axis=bn_axis,
momentum=BATCH_NORM_DECAY, momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON, epsilon=BATCH_NORM_EPSILON,
name='bn_conv1', name='bn_conv1',
)(x, training=training) )(x, training=training)
x = layers.Activation('relu')(x) x = tf.keras.layers.Activation('relu')(x)
x = resnet_block( x = resnet_block(
x, x,
...@@ -309,14 +305,19 @@ def resnet(num_blocks, classes=10, training=None): ...@@ -309,14 +305,19 @@ def resnet(num_blocks, classes=10, training=None):
conv_strides=(2, 2), conv_strides=(2, 2),
training=training) training=training)
rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3] if tf.keras.backend.image_data_format() == 'channels_last':
x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x) rm_axes = [1, 2]
x = layers.Dense( else:
rm_axes = [2, 3]
x = tf.keras.layers.Lambda(
lambda x: tf.keras.backend.mean(x, rm_axes), name='reduce_mean')(x)
x = tf.keras.layers.Dense(
classes, classes,
activation='softmax', activation='softmax',
kernel_initializer=initializers.RandomNormal(stddev=0.01), kernel_initializer=tf.keras.initializers.RandomNormal(
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY), stddev=0.01),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY), kernel_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
bias_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
name='fc10')( name='fc10')(
x) x)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment