Commit 4ca9e10a authored by Scott Zhu's avatar Scott Zhu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 377092815
parent dfce8c78
......@@ -26,10 +26,6 @@ from __future__ import print_function
import functools
import tensorflow as tf
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras import layers
from tensorflow.python.keras import regularizers
BATCH_NORM_DECAY = 0.997
BATCH_NORM_EPSILON = 1e-5
......@@ -57,48 +53,48 @@ def identity_building_block(input_tensor,
Output tensor for the block.
"""
filters1, filters2 = filters
if backend.image_data_format() == 'channels_last':
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = 3
else:
bn_axis = 1
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = layers.Conv2D(
x = tf.keras.layers.Conv2D(
filters1,
kernel_size,
padding='same',
use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
kernel_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(
input_tensor)
x = layers.BatchNormalization(
x = tf.keras.layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(
x, training=training)
x = layers.Activation('relu')(x)
x = tf.keras.layers.Activation('relu')(x)
x = layers.Conv2D(
x = tf.keras.layers.Conv2D(
filters2,
kernel_size,
padding='same',
use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
kernel_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(
x)
x = layers.BatchNormalization(
x = tf.keras.layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(
x, training=training)
x = layers.add([x, input_tensor])
x = layers.Activation('relu')(x)
x = tf.keras.layers.add([x, input_tensor])
x = tf.keras.layers.Activation('relu')(x)
return x
......@@ -136,57 +132,57 @@ def conv_building_block(input_tensor,
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = layers.Conv2D(
x = tf.keras.layers.Conv2D(
filters1,
kernel_size,
strides=strides,
padding='same',
use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
kernel_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(
input_tensor)
x = layers.BatchNormalization(
x = tf.keras.layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(
x, training=training)
x = layers.Activation('relu')(x)
x = tf.keras.layers.Activation('relu')(x)
x = layers.Conv2D(
x = tf.keras.layers.Conv2D(
filters2,
kernel_size,
padding='same',
use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
kernel_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(
x)
x = layers.BatchNormalization(
x = tf.keras.layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(
x, training=training)
shortcut = layers.Conv2D(
shortcut = tf.keras.layers.Conv2D(
filters2, (1, 1),
strides=strides,
use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
kernel_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
name=conv_name_base + '1')(
input_tensor)
shortcut = layers.BatchNormalization(
shortcut = tf.keras.layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '1')(
shortcut, training=training)
x = layers.add([x, shortcut])
x = layers.Activation('relu')(x)
x = tf.keras.layers.add([x, shortcut])
x = tf.keras.layers.Activation('relu')(x)
return x
......@@ -252,11 +248,11 @@ def resnet(num_blocks, classes=10, training=None):
"""
input_shape = (32, 32, 3)
img_input = layers.Input(shape=input_shape)
img_input = tf.keras.Input(shape=input_shape)
if backend.image_data_format() == 'channels_first':
x = layers.Lambda(
lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
if tf.keras.backend.image_data_format() == 'channels_first':
x = tf.keras.layers.Lambda(
lambda x: tf.keras.backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')(
img_input)
bn_axis = 1
......@@ -264,23 +260,23 @@ def resnet(num_blocks, classes=10, training=None):
x = img_input
bn_axis = 3
x = layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x)
x = layers.Conv2D(
x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x)
x = tf.keras.layers.Conv2D(
16, (3, 3),
strides=(1, 1),
padding='valid',
use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
kernel_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
name='conv1')(
x)
x = layers.BatchNormalization(
x = tf.keras.layers.BatchNormalization(
axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name='bn_conv1',
)(x, training=training)
x = layers.Activation('relu')(x)
x = tf.keras.layers.Activation('relu')(x)
x = resnet_block(
x,
......@@ -309,14 +305,19 @@ def resnet(num_blocks, classes=10, training=None):
conv_strides=(2, 2),
training=training)
rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3]
x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x)
x = layers.Dense(
if tf.keras.backend.image_data_format() == 'channels_last':
rm_axes = [1, 2]
else:
rm_axes = [2, 3]
x = tf.keras.layers.Lambda(
lambda x: tf.keras.backend.mean(x, rm_axes), name='reduce_mean')(x)
x = tf.keras.layers.Dense(
classes,
activation='softmax',
kernel_initializer=initializers.RandomNormal(stddev=0.01),
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
kernel_initializer=tf.keras.initializers.RandomNormal(
stddev=0.01),
kernel_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
bias_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
name='fc10')(
x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment