"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "a76e250f097d7c7b79cfbaea8b4d426463b0218f"
Unverified Commit c0b31c51 authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Fix trivial model to work properly with fp16 (#6760)

* Fix trivial model to work properly with fp16

* Add comment on manual casting
parent 5e876e6e
...@@ -201,7 +201,7 @@ def run(flags_obj): ...@@ -201,7 +201,7 @@ def run(flags_obj):
input_layer_batch_size = None input_layer_batch_size = None
if flags_obj.use_trivial_model: if flags_obj.use_trivial_model:
model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES) model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES, dtype)
else: else:
model = resnet_model.resnet50( model = resnet_model.resnet50(
num_classes=imagenet_main.NUM_CLASSES, num_classes=imagenet_main.NUM_CLASSES,
......
...@@ -23,15 +23,19 @@ from tensorflow.python.keras import layers ...@@ -23,15 +23,19 @@ from tensorflow.python.keras import layers
from tensorflow.python.keras import models from tensorflow.python.keras import models
def trivial_model(num_classes): def trivial_model(num_classes, dtype='float32'):
"""Trivial model for ImageNet dataset.""" """Trivial model for ImageNet dataset."""
input_shape = (224, 224, 3) input_shape = (224, 224, 3)
img_input = layers.Input(shape=input_shape) img_input = layers.Input(shape=input_shape, dtype=dtype)
x = layers.Lambda(lambda x: backend.reshape(x, [-1, 224 * 224 * 3]), x = layers.Lambda(lambda x: backend.reshape(x, [-1, 224 * 224 * 3]),
name='reshape')(img_input) name='reshape')(img_input)
x = layers.Dense(1, name='fc1')(x) x = layers.Dense(1, name='fc1')(x)
x = layers.Dense(num_classes, activation='softmax', name='fc1000')(x) x = layers.Dense(num_classes, name='fc1000')(x)
# TODO(reedwm): Remove manual casts once mixed precision can be enabled with a
# single line of code.
x = backend.cast(x, 'float32')
x = layers.Activation('softmax')(x)
return models.Model(img_input, x, name='trivial') return models.Model(img_input, x, name='trivial')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment