Unverified Commit b08843cf authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add a `parallel_mode` property to TrainingArguments (#8877)

* Add a `distributed_env` property to TrainingArguments

* Change name

* Address comment
parent 7c10dd22
...@@ -465,6 +465,27 @@ class TrainingArguments: ...@@ -465,6 +465,27 @@ class TrainingArguments:
""" """
return self._setup_devices[1] return self._setup_devices[1]
@property
@torch_required
def parallel_mode(self):
    """
    The current mode used for parallelism if multiple GPUs/TPU cores are available. One of:

    - :obj:`ParallelMode.NOT_PARALLEL`: no parallelism (CPU or one GPU).
    - :obj:`ParallelMode.NOT_DISTRIBUTED`: several GPUs in one single process (uses :obj:`torch.nn.DataParallel`).
    - :obj:`ParallelMode.DISTRIBUTED`: several GPUs, each having its own process (uses
      :obj:`torch.nn.DistributedDataParallel`).
    - :obj:`ParallelMode.TPU`: several TPU cores.
    """
    # Check in priority order: TPU availability wins, then an explicit
    # distributed local rank, then plain multi-GPU in a single process.
    if is_torch_tpu_available():
        return ParallelMode.TPU
    if self.local_rank != -1:
        return ParallelMode.DISTRIBUTED
    if self.n_gpu > 1:
        return ParallelMode.NOT_DISTRIBUTED
    return ParallelMode.NOT_PARALLEL
def to_dict(self): def to_dict(self):
""" """
Serializes this instance while replace `Enum` by their values (for JSON serialization support). Serializes this instance while replace `Enum` by their values (for JSON serialization support).
...@@ -493,3 +514,10 @@ class TrainingArguments: ...@@ -493,3 +514,10 @@ class TrainingArguments:
valid_types.append(torch.Tensor) valid_types.append(torch.Tensor)
return {k: v if type(v) in valid_types else str(v) for k, v in d.items()} return {k: v if type(v) in valid_types else str(v) for k, v in d.items()}
class ParallelMode(Enum):
    """Possible parallelism setups reported by :obj:`TrainingArguments.parallel_mode`."""

    NOT_PARALLEL = "not_parallel"  # CPU or a single GPU
    NOT_DISTRIBUTED = "not_distributed"  # several GPUs in one process (DataParallel)
    DISTRIBUTED = "distributed"  # one process per GPU (DistributedDataParallel)
    TPU = "tpu"  # TPU cores
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment