"vscode:/vscode.git/clone" did not exist on "a5121e57d6f56b9c8b4b81838a382ae13cc3389b"
Unverified Commit b09685fe authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Add trivial Keras model (#6460)

parent 0b2b8997
......@@ -253,6 +253,8 @@ def define_keras_flags():
"""Define flags for Keras models."""
flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?')
flags.DEFINE_boolean(name='use_trivial_model', default=False,
help='Whether to use a trivial Keras model.')
flags.DEFINE_boolean(
name='enable_xla', default=False,
help='Whether to enable XLA auto jit compilation. This is still an '
......
......@@ -25,6 +25,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import imagenet_main
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_model
from official.resnet.keras import trivial_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
......@@ -164,8 +165,12 @@ def run(flags_obj):
# can be enabled with a single line of code.
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))
model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES,
dtype=dtype)
if flags_obj.use_trivial_model:
model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES)
else:
model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES,
dtype=dtype)
model.compile(loss='sparse_categorical_crossentropy',
optimizer=optimizer,
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A trivial model for Keras."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras import models
def trivial_model(num_classes):
"""Trivial model for ImageNet dataset."""
input_shape = (224, 224, 3)
img_input = layers.Input(shape=input_shape)
x = layers.Lambda(lambda x: backend.reshape(x, [-1, 224 * 224 * 3]),
name='reshape')(img_input)
x = layers.Dense(num_classes, activation='softmax', name='fc1000')(x)
return models.Model(img_input, x, name='trivial')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment