"tests/vscode:/vscode.git/clone" did not exist on "f079e322b857714fcef1ada9e78ddc606fe51e84"
Unverified Commit 127e81c2 authored by Zach Mueller's avatar Zach Mueller Committed by GitHub
Browse files

Remove redundant code from TrainingArgs (#24401)

Remove redundant code
parent cd927a47
......@@ -1815,15 +1815,10 @@ class TrainingArguments:
The number of processes used in parallel.
"""
requires_backends(self, ["torch"])
if is_torch_tpu_available():
return xm.xrt_world_size()
if self.distributed_state is not None:
return self.distributed_state.num_processes
elif is_sagemaker_mp_enabled():
return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size()
elif is_sagemaker_dp_enabled():
return dist.get_world_size()
elif self.parallel_mode == ParallelMode.DISTRIBUTED:
return torch.distributed.get_world_size()
return 1
@property
......@@ -1832,14 +1827,10 @@ class TrainingArguments:
The index of the current process used.
"""
requires_backends(self, ["torch"])
if is_torch_tpu_available():
return xm.get_ordinal()
if self.distributed_state is not None:
return self.distributed_state.process_index
elif is_sagemaker_mp_enabled():
return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank()
elif is_sagemaker_dp_enabled():
return dist.get_rank()
elif self.parallel_mode == ParallelMode.DISTRIBUTED:
return torch.distributed.get_rank()
return 0
@property
......@@ -1848,14 +1839,11 @@ class TrainingArguments:
The index of the local process used.
"""
requires_backends(self, ["torch"])
if is_torch_tpu_available():
return xm.get_local_ordinal()
if self.distributed_state is not None:
return self.distributed_state.local_process_index
elif is_sagemaker_mp_enabled():
return smp.local_rank()
elif is_sagemaker_dp_enabled():
return dist.get_rank()
elif self.parallel_mode == ParallelMode.DISTRIBUTED:
return self.local_rank
return 0
@property
......@@ -1944,19 +1932,19 @@ class TrainingArguments:
"""
if is_torch_available() and self.world_size > 1:
main_process_desc = "main process"
if local:
is_main_process = self.local_process_index == 0
main_process_desc = "main local process"
main_process_desc = "main local process" if local else "main process"
if self.distributed_state is not None:
is_main_process = (
self.distributed_state.is_local_main_process if local else self.distributed_state.is_main_process
)
elif is_sagemaker_mp_enabled():
is_main_process = smp.rank() == 0
else:
is_main_process = self.process_index == 0
try:
if not is_main_process:
# tell all replicas to wait
logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")
if is_torch_tpu_available():
xm.rendezvous(desc)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment