Commit 82b56ca7 authored by Shining Sun's avatar Shining Sun
Browse files

Use strategy.scope()

parent 2c4762d0
...@@ -132,16 +132,16 @@ def run(flags_obj): ...@@ -132,16 +132,16 @@ def run(flags_obj):
num_epochs=flags_obj.train_epochs, num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras) parse_record_fn=parse_record_keras)
optimizer = keras_common.get_optimizer()
strategy = distribution_utils.get_distribution_strategy( strategy = distribution_utils.get_distribution_strategy(
flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy) flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy)
model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES) with strategy.scope():
optimizer = keras_common.get_optimizer()
model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)
model.compile(loss='categorical_crossentropy', model.compile(loss='categorical_crossentropy',
optimizer=optimizer, optimizer=optimizer,
metrics=['categorical_accuracy'], metrics=['categorical_accuracy'])
distribute=strategy)
time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks( time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
learning_rate_schedule, cifar_main.NUM_IMAGES['train']) learning_rate_schedule, cifar_main.NUM_IMAGES['train'])
......
...@@ -121,16 +121,15 @@ def run(flags_obj): ...@@ -121,16 +121,15 @@ def run(flags_obj):
num_epochs=flags_obj.train_epochs, num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras) parse_record_fn=parse_record_keras)
optimizer = keras_common.get_optimizer()
strategy = distribution_utils.get_distribution_strategy( strategy = distribution_utils.get_distribution_strategy(
flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy) flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy)
model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES) with strategy.scope():
optimizer = keras_common.get_optimizer()
model.compile(loss='sparse_categorical_crossentropy', model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)
optimizer=optimizer, model.compile(loss='sparse_categorical_crossentropy',
metrics=['sparse_categorical_accuracy'], optimizer=optimizer,
distribute=strategy) metrics=['sparse_categorical_accuracy'])
time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks( time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
learning_rate_schedule, imagenet_main.NUM_IMAGES['train']) learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment