"tutorials/vscode:/vscode.git/clone" did not exist on "eb1acecd0d9057d116fa8c8687123600a5065557"
Commit 3024bde6 authored by Soroush Radpour's avatar Soroush Radpour Committed by Yuefeng Zhou
Browse files

Add the option to run Keras resnet model on multiple workers. (#6368)

parent cf304238
...@@ -144,7 +144,8 @@ def run(flags_obj): ...@@ -144,7 +144,8 @@ def run(flags_obj):
strategy = distribution_utils.get_distribution_strategy( strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy, distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus) num_gpus=flags_obj.num_gpus,
num_workers=distribution_utils.configure_cluster())
strategy_scope = keras_common.get_strategy_scope(strategy) strategy_scope = keras_common.get_strategy_scope(strategy)
......
...@@ -229,6 +229,8 @@ def configure_cluster(worker_hosts=None, task_index=-1): ...@@ -229,6 +229,8 @@ def configure_cluster(worker_hosts=None, task_index=-1):
tf_config = json.loads(os.environ.get('TF_CONFIG', '{}')) tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
if tf_config: if tf_config:
num_workers = len(tf_config['cluster']['worker']) num_workers = len(tf_config['cluster']['worker'])
if tf_config['cluster'].get('chief', None):
num_workers += 1
elif worker_hosts: elif worker_hosts:
workers = worker_hosts.split(',') workers = worker_hosts.split(',')
num_workers = len(workers) num_workers = len(workers)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment