Commit b86ffb12 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 301639338
parent febaae9a
......@@ -255,9 +255,7 @@ def resnet50(num_classes,
x = img_input
if backend.image_data_format() == 'channels_first':
x = layers.Lambda(
lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')(x)
x = layers.Permute((3, 1, 2))(x)
bn_axis = 1
else: # channels_last
bn_axis = 3
......@@ -382,8 +380,7 @@ def resnet50(num_classes,
block='c',
use_l2_regularizer=use_l2_regularizer)
rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3]
x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(
num_classes,
kernel_initializer=initializers.RandomNormal(stddev=0.01),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment