Commit 2fbba5ff authored by Shining Sun's avatar Shining Sun
Browse files

Bug fix: None check on strategy

parent 82b56ca7
...@@ -135,7 +135,9 @@ def run(flags_obj): ...@@ -135,7 +135,9 @@ def run(flags_obj):
strategy = distribution_utils.get_distribution_strategy( strategy = distribution_utils.get_distribution_strategy(
flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy) flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy)
with strategy.scope(): strategy_scope = keras_common.get_strategy_scope(strategy)
with strategy_scope:
optimizer = keras_common.get_optimizer() optimizer = keras_common.get_optimizer()
model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES) model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)
......
...@@ -210,3 +210,22 @@ def get_synth_input_fn(height, width, num_channels, num_classes, ...@@ -210,3 +210,22 @@ def get_synth_input_fn(height, width, num_channels, num_classes,
return data return data
return input_fn return input_fn
def get_strategy_scope(strategy):
if strategy:
strategy_scope = strategy.scope()
else:
strategy_scope = keras_common.DummyContextManager()
return strategy_scope
class DummyContextManager(object):
def __enter__(self):
pass
def __exit__(self, *args):
pass
...@@ -124,9 +124,12 @@ def run(flags_obj): ...@@ -124,9 +124,12 @@ def run(flags_obj):
strategy = distribution_utils.get_distribution_strategy( strategy = distribution_utils.get_distribution_strategy(
flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy) flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy)
with strategy.scope(): strategy_scope = keras_common.get_strategy_scope(strategy)
with strategy_scope:
optimizer = keras_common.get_optimizer() optimizer = keras_common.get_optimizer()
model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES) model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)
model.compile(loss='sparse_categorical_crossentropy', model.compile(loss='sparse_categorical_crossentropy',
optimizer=optimizer, optimizer=optimizer,
metrics=['sparse_categorical_accuracy']) metrics=['sparse_categorical_accuracy'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment