Commit deaf1951 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Fixes the type of the ParameterServerStrategy.

PiperOrigin-RevId: 361587168
parent c9be6b26
...@@ -170,7 +170,8 @@ def get_distribution_strategy(distribution_strategy="mirrored", ...@@ -170,7 +170,8 @@ def get_distribution_strategy(distribution_strategy="mirrored",
cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs)) cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))
if distribution_strategy == "parameter_server": if distribution_strategy == "parameter_server":
return tf.compat.v1.distribute.experimental.ParameterServerStrategy() cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
return tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
raise ValueError("Unrecognized Distribution Strategy: %r" % raise ValueError("Unrecognized Distribution Strategy: %r" %
distribution_strategy) distribution_strategy)
...@@ -181,6 +182,7 @@ def configure_cluster(worker_hosts=None, task_index=-1): ...@@ -181,6 +182,7 @@ def configure_cluster(worker_hosts=None, task_index=-1):
Args: Args:
worker_hosts: comma-separated list of worker ip:port pairs. worker_hosts: comma-separated list of worker ip:port pairs.
task_index: index of the worker.
Returns: Returns:
Number of workers in the cluster. Number of workers in the cluster.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment