Unverified Commit 0ff0717b authored by Shining Sun's avatar Shining Sun Committed by GitHub
Browse files

Merge pull request #5987 from tensorflow/scope

Use the new distribution strategy API with keras (creating the model inside strategy.scope() instead of passing it to model.compile)
parents 2c4762d0 2fbba5ff
......@@ -132,16 +132,18 @@ def run(flags_obj):
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
optimizer = keras_common.get_optimizer()
strategy = distribution_utils.get_distribution_strategy(
flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy)
model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)
strategy_scope = keras_common.get_strategy_scope(strategy)
model.compile(loss='categorical_crossentropy',
optimizer=optimizer,
metrics=['categorical_accuracy'],
distribute=strategy)
with strategy_scope:
optimizer = keras_common.get_optimizer()
model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)
model.compile(loss='categorical_crossentropy',
optimizer=optimizer,
metrics=['categorical_accuracy'])
time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
learning_rate_schedule, cifar_main.NUM_IMAGES['train'])
......
......@@ -210,3 +210,22 @@ def get_synth_input_fn(height, width, num_channels, num_classes,
return data
return input_fn
def get_strategy_scope(strategy):
if strategy:
strategy_scope = strategy.scope()
else:
strategy_scope = keras_common.DummyContextManager()
return strategy_scope
class DummyContextManager(object):
def __enter__(self):
pass
def __exit__(self, *args):
pass
......@@ -121,16 +121,18 @@ def run(flags_obj):
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
optimizer = keras_common.get_optimizer()
strategy = distribution_utils.get_distribution_strategy(
flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy)
model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)
strategy_scope = keras_common.get_strategy_scope(strategy)
model.compile(loss='sparse_categorical_crossentropy',
optimizer=optimizer,
metrics=['sparse_categorical_accuracy'],
distribute=strategy)
with strategy_scope:
optimizer = keras_common.get_optimizer()
model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)
model.compile(loss='sparse_categorical_crossentropy',
optimizer=optimizer,
metrics=['sparse_categorical_accuracy'])
time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment