Unverified Commit 1529b82c authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Replace GlobalAvgPooling with reduce mean to reduce cross-device overhead (#6837)

parent 76256146
......@@ -230,7 +230,8 @@ def resnet50(num_classes, dtype='float32', batch_size=None):
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3]
x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x)
x = layers.Dense(
num_classes,
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment