Commit 906d712e authored by Jaehong Kim's avatar Jaehong Kim Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 445390384
parent f58534eb
...@@ -861,8 +861,7 @@ class MobileNet(tf.keras.Model): ...@@ -861,8 +861,7 @@ class MobileNet(tf.keras.Model):
net = block(net) net = block(net)
elif block_def.block_fn == 'gpooling': elif block_def.block_fn == 'gpooling':
net = layers.GlobalAveragePooling2D()(net) net = layers.GlobalAveragePooling2D(keepdims=True)(net)
net = layers.Reshape((1, 1, net.shape[1]))(net)
else: else:
raise ValueError('Unknown block type {} for layer {}'.format( raise ValueError('Unknown block type {} for layer {}'.format(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment