Commit 42420ce8 authored by Shining Sun's avatar Shining Sun
Browse files

Clean up keras_resnet_model and rename it to resnet56

parent cd034c8c
...@@ -28,7 +28,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order ...@@ -28,7 +28,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import cifar10_main as cifar_main from official.resnet import cifar10_main as cifar_main
from official.resnet import resnet_run_loop from official.resnet import resnet_run_loop
from official.resnet.keras import keras_common from official.resnet.keras import keras_common
from official.resnet.keras import keras_resnet_model from official.resnet.keras import resnet56
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.logs import logger from official.utils.logs import logger
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
...@@ -167,10 +167,9 @@ def run_cifar_with_keras(flags_obj): ...@@ -167,10 +167,9 @@ def run_cifar_with_keras(flags_obj):
opt, loss, accuracy = keras_common.get_optimizer_loss_and_metrics() opt, loss, accuracy = keras_common.get_optimizer_loss_and_metrics()
strategy = keras_common.get_dist_strategy() strategy = keras_common.get_dist_strategy()
model = keras_resnet_model.ResNet56(input_shape=(32, 32, 3), model = resnet56.ResNet56(input_shape=(32, 32, 3),
include_top=True, classes=cifar_main._NUM_CLASSES)
classes=cifar_main._NUM_CLASSES,
weights=None)
model.compile(loss=loss, model.compile(loss=loss,
optimizer=opt, optimizer=opt,
metrics=[accuracy], metrics=[accuracy],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment