Commit b86ffb12 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 301639338
parent febaae9a
...@@ -255,9 +255,7 @@ def resnet50(num_classes, ...@@ -255,9 +255,7 @@ def resnet50(num_classes,
x = img_input x = img_input
if backend.image_data_format() == 'channels_first': if backend.image_data_format() == 'channels_first':
x = layers.Lambda( x = layers.Permute((3, 1, 2))(x)
lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')(x)
bn_axis = 1 bn_axis = 1
else: # channels_last else: # channels_last
bn_axis = 3 bn_axis = 3
...@@ -382,8 +380,7 @@ def resnet50(num_classes, ...@@ -382,8 +380,7 @@ def resnet50(num_classes,
block='c', block='c',
use_l2_regularizer=use_l2_regularizer) use_l2_regularizer=use_l2_regularizer)
rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3] x = layers.GlobalAveragePooling2D()(x)
x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x)
x = layers.Dense( x = layers.Dense(
num_classes, num_classes,
kernel_initializer=initializers.RandomNormal(stddev=0.01), kernel_initializer=initializers.RandomNormal(stddev=0.01),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment