Commit 42420ce8 authored by Shining Sun's avatar Shining Sun
Browse files

Clean up keras_resnet_model and rename it to resnet56

parent cd034c8c
......@@ -28,7 +28,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import cifar10_main as cifar_main
from official.resnet import resnet_run_loop
from official.resnet.keras import keras_common
from official.resnet.keras import keras_resnet_model
from official.resnet.keras import resnet56
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
......@@ -167,10 +167,9 @@ def run_cifar_with_keras(flags_obj):
opt, loss, accuracy = keras_common.get_optimizer_loss_and_metrics()
strategy = keras_common.get_dist_strategy()
model = keras_resnet_model.ResNet56(input_shape=(32, 32, 3),
include_top=True,
classes=cifar_main._NUM_CLASSES,
weights=None)
model = resnet56.ResNet56(input_shape=(32, 32, 3),
classes=cifar_main._NUM_CLASSES)
model.compile(loss=loss,
optimizer=opt,
metrics=[accuracy],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment