"vscode:/vscode.git/clone" did not exist on "54b352d99854417a84cce55243f5d3215f496b66"
Unverified Commit 6d3989eb authored by Ayush Dubey's avatar Ayush Dubey Committed by GitHub
Browse files

Re-enable checkpoints for multi worker GPU strategies. (#6471)

parent f5bb2af2
...@@ -550,14 +550,10 @@ def resnet_main( ...@@ -550,14 +550,10 @@ def resnet_main(
# Creates a `RunConfig` that checkpoints every 24 hours which essentially # Creates a `RunConfig` that checkpoints every 24 hours which essentially
# results in checkpoints determined only by `epochs_between_evals`. # results in checkpoints determined only by `epochs_between_evals`.
# TODO(ayushd,yuefengz): re-enable checkpointing for multi-worker strategy.
save_checkpoints_secs = (None if distribution_strategy.__class__.__name__ in
['CollectiveAllReduceStrategy',
'MultiWorkerMirroredStrategy'] else 60*60*24)
run_config = tf.estimator.RunConfig( run_config = tf.estimator.RunConfig(
train_distribute=distribution_strategy, train_distribute=distribution_strategy,
session_config=session_config, session_config=session_config,
save_checkpoints_secs=save_checkpoints_secs, save_checkpoints_secs=60*60*24,
save_checkpoints_steps=None) save_checkpoints_steps=None)
# Initializes model with all but the dense layer from pretrained ResNet. # Initializes model with all but the dense layer from pretrained ResNet.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment