"vscode:/vscode.git/clone" did not exist on "cc93136e6a166566fc6f0502c67aa99a94673db3"
Commit 82b56ca7 authored by Shining Sun's avatar Shining Sun
Browse files

Use strategy.scope()

parent 2c4762d0
......@@ -132,16 +132,16 @@ def run(flags_obj):
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
optimizer = keras_common.get_optimizer()
strategy = distribution_utils.get_distribution_strategy(
flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy)
with strategy.scope():
optimizer = keras_common.get_optimizer()
model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)
model.compile(loss='categorical_crossentropy',
optimizer=optimizer,
metrics=['categorical_accuracy'],
distribute=strategy)
metrics=['categorical_accuracy'])
time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
learning_rate_schedule, cifar_main.NUM_IMAGES['train'])
......
......@@ -121,16 +121,15 @@ def run(flags_obj):
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
optimizer = keras_common.get_optimizer()
strategy = distribution_utils.get_distribution_strategy(
flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy)
with strategy.scope():
optimizer = keras_common.get_optimizer()
model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)
model.compile(loss='sparse_categorical_crossentropy',
optimizer=optimizer,
metrics=['sparse_categorical_accuracy'],
distribute=strategy)
metrics=['sparse_categorical_accuracy'])
time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment