Unverified Commit fe58929a authored by Gustavo de Rosa, committed by GitHub

Adds timeout argument to training_args to avoid socket timeouts in DDP (#18562)

* chore(training_args): Adds support for timeout argument.

* fix(training_args): Passes make style through changes.

* fix(training_args): Removes wrong docstring sentence.

* fix(training_args): Fixes timeout not being JSON serializable.

* fix(training_args_sm): Also updates timeout to timeout_delta.

* fix(training_args): Fixes PR according to suggestions.
parent ab663b22
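For context, here is a minimal usage sketch of the new argument from user code, assuming a transformers build that includes this change; `output_dir` and the 7200-second value are placeholder choices.

```python
from datetime import timedelta

from transformers import TrainingArguments

# Raise the DDP timeout above the 1800-second default, e.g. for jobs that do
# slow, rank-0-only preprocessing before the first collective call.
args = TrainingArguments(output_dir="out", ddp_timeout=7200)

# The companion property converts the integer seconds into the timedelta
# that torch.distributed.init_process_group expects.
assert args.ddp_timeout_delta == timedelta(seconds=7200)
```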
@@ -92,7 +92,7 @@ class SageMakerTrainingArguments(TrainingArguments):
         elif is_sagemaker_dp_enabled():
             import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401

-            torch.distributed.init_process_group(backend="smddp")
+            torch.distributed.init_process_group(backend="smddp", timeout=self.ddp_timeout_delta)
             self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
@@ -111,7 +111,7 @@ class SageMakerTrainingArguments(TrainingArguments):
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
             if not torch.distributed.is_initialized():
-                torch.distributed.init_process_group(backend="nccl")
+                torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
@@ -18,6 +18,7 @@ import math
 import os
 import warnings
 from dataclasses import asdict, dataclass, field
+from datetime import timedelta
 from enum import Enum
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
@@ -481,6 +482,11 @@ class TrainingArguments:
             are also available. See the [Ray documentation](
             https://docs.ray.io/en/latest/tune/api_docs/analysis.html#ray.tune.ExperimentAnalysis.get_best_trial) for
             more options.
+        ddp_timeout (`int`, *optional*, defaults to 1800):
+            The timeout for `torch.distributed.init_process_group` calls, used to avoid GPU socket timeouts when
+            performing slow operations in distributed runs. Please refer to the [PyTorch documentation]
+            (https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more
+            information.
         use_mps_device (`bool`, *optional*, defaults to `False`):
             Whether to use Apple Silicon chip based `mps` device.
     """
@@ -971,6 +977,12 @@ class TrainingArguments:
             )
         },
     )
+    ddp_timeout: Optional[int] = field(
+        default=1800,
+        metadata={
+            "help": "Overrides the default timeout for distributed training (value should be given in seconds)."
+        },
+    )

     def __post_init__(self):
         # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
@@ -1291,6 +1303,13 @@ class TrainingArguments:
         eval_batch_size = per_device_batch_size * max(1, self.n_gpu)
         return eval_batch_size

+    @property
+    def ddp_timeout_delta(self) -> timedelta:
+        """
+        The actual timeout for torch.distributed.init_process_group since it expects a timedelta variable.
+        """
+        return timedelta(seconds=self.ddp_timeout)
+
     @cached_property
     @torch_required
     def _setup_devices(self) -> "torch.device":
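To show the shape of the underlying PyTorch call that the new property feeds into, here is a standalone, single-process sketch using the `gloo` backend and the `env://` rendezvous; the address, port, and timeout value are arbitrary placeholders.

```python
import os
from datetime import timedelta

import torch.distributed as dist

# init_process_group takes its timeout as a datetime.timedelta, not an int,
# which is why TrainingArguments converts ddp_timeout via ddp_timeout_delta
# before passing it along.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(
    backend="gloo",  # CPU-friendly backend for a quick local check
    rank=0,
    world_size=1,
    timeout=timedelta(seconds=7200),
)
dist.destroy_process_group()
```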
@@ -1358,7 +1377,9 @@ class TrainingArguments:
                         f"num_cpu_threads_per_process unset, we set it at {num_cpu_threads_per_process} to improve oob"
                         " performance."
                     )
-            torch.distributed.init_process_group(backend=self.xpu_backend, rank=rank, world_size=size)
+            torch.distributed.init_process_group(
+                backend=self.xpu_backend, rank=rank, world_size=size, timeout=self.ddp_timeout_delta
+            )
         elif is_torch_tpu_available():
             device = xm.xla_device()
             self._n_gpu = 0
@@ -1369,7 +1390,7 @@ class TrainingArguments:
         elif is_sagemaker_dp_enabled():
             import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401

-            dist.init_process_group(backend="smddp")
+            dist.init_process_group(backend="smddp", timeout=self.ddp_timeout_delta)
             self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
@@ -1431,7 +1452,7 @@ class TrainingArguments:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
             if not torch.distributed.is_initialized():
-                torch.distributed.init_process_group(backend="nccl")
+                torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1