Commit 57e7ca73 authored by A. Unique TensorFlower

Updating classifier_trainer MultiWorkerMirrored Strategy.

PiperOrigin-RevId: 316915450
parent e9df75ab
@@ -119,6 +119,24 @@ python3 classifier_trainer.py \
  --params_override='runtime.num_gpus=$NUM_GPUS'
```
To train on multiple hosts, each with GPUs attached, using
[MultiWorkerMirroredStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy),
update the `runtime` section in `gpu.yaml`
(or override it using `--params_override`) with:
```YAML
# gpu.yaml
runtime:
distribution_strategy: 'multi_worker_mirrored'
worker_hosts: '$HOST1:port,$HOST2:port'
num_gpus: $NUM_GPUS
task_index: 0
```
Set `task_index: 0` on the first host, `task_index: 1` on the second, and so
on. `$HOST1` and `$HOST2` are the IP addresses of the hosts, and `port` can be
any free port on the hosts. Only the first host will write TensorBoard
summaries and save checkpoints.
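For example (a sketch, not part of this commit: the flag values mirror the
single-host GPU example earlier in this README, and `$CONFIG_FILE` stands for
the `gpu.yaml` edited as above), the launch on each host could look like:
```bash
# Hypothetical two-host launch; flags mirror the single-host GPU example above,
# $CONFIG_FILE points at the gpu.yaml edited as described, and $MODEL_DIR /
# $DATA_DIR are placeholders.

# Host 1 ($HOST1): gpu.yaml already sets task_index: 0.
python3 classifier_trainer.py \
  --mode=train_and_eval \
  --model_type=resnet \
  --dataset=imagenet \
  --model_dir=$MODEL_DIR \
  --data_dir=$DATA_DIR \
  --config_file=$CONFIG_FILE

# Host 2 ($HOST2): identical command, overriding only the task index.
python3 classifier_trainer.py \
  --mode=train_and_eval \
  --model_type=resnet \
  --dataset=imagenet \
  --model_dir=$MODEL_DIR \
  --data_dir=$DATA_DIR \
  --config_file=$CONFIG_FILE \
  --params_override='runtime.task_index=1'
```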
#### On TPU:
```bash
python3 classifier_trainer.py \
...
@@ -235,9 +235,6 @@ def initialize(params: base_configs.ExperimentConfig,
  else:
    data_format = 'channels_last'
  tf.keras.backend.set_image_data_format(data_format)
  distribution_utils.configure_cluster(
      params.runtime.worker_hosts,
      params.runtime.task_index)
  if params.runtime.run_eagerly:
    # Enable eager execution to allow step-by-step debugging
    tf.config.experimental_run_functions_eagerly(True)
@@ -296,6 +293,10 @@ def train_and_eval(
  """Runs the train and eval path using compile/fit."""
  logging.info('Running train and eval.')
  distribution_utils.configure_cluster(
      params.runtime.worker_hosts,
      params.runtime.task_index)
  # Note: for TPUs, strategy and scope should be created before the dataset
  strategy = strategy_override or distribution_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
...
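As the hunk headers indicate, this commit moves the
`distribution_utils.configure_cluster(...)` call out of `initialize()` and into
`train_and_eval()`, so the cluster is configured immediately before the
distribution strategy is created. For `multi_worker_mirrored`, the strategy's
default cluster resolver reads the `TF_CONFIG` environment variable, which is
presumably what `configure_cluster` populates from `worker_hosts` and
`task_index`. A minimal sketch of the equivalent manual setup (the addresses
and port below are placeholders):
```bash
# Sketch only: what the worker_hosts / task_index settings amount to for
# MultiWorkerMirroredStrategy's default (TF_CONFIG-based) cluster resolver.
# 10.0.0.1 / 10.0.0.2 and port 2222 are placeholder values.
export TF_CONFIG='{
  "cluster": {"worker": ["10.0.0.1:2222", "10.0.0.2:2222"]},
  "task": {"type": "worker", "index": 0}
}'
```
On the second host the same cluster spec would be used with `"index": 1`, which
matches setting `task_index: 1` in the YAML shown above.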