"docs/vscode:/vscode.git/clone" did not exist on "6c08840628a22a4d53ae563d1041479649d1a8e7"
Unverified commit 127e81c2, authored by Zach Mueller, committed by GitHub

Remove redundant code from TrainingArgs (#24401)

Remove redundant code
parent cd927a47
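
Context for the change: `TrainingArguments` delegates world-size and rank lookups to `self.distributed_state`, an `accelerate` `PartialState` that it sets up internally and that already resolves these values across TPU, SageMaker DP, and plain `torch.distributed` runs. That is what makes the per-backend branches removed below redundant. A minimal sketch of the attributes being relied on (illustrative only; not how `TrainingArguments` wires its own state):

from accelerate import PartialState

state = PartialState()
print(state.num_processes)        # covers the removed xm.xrt_world_size() / dist.get_world_size() branches
print(state.process_index)        # covers the removed xm.get_ordinal() / dist.get_rank() branches
print(state.local_process_index)  # covers the removed xm.get_local_ordinal() / self.local_rank branches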
@@ -1815,15 +1815,10 @@ class TrainingArguments:
         The number of processes used in parallel.
         """
         requires_backends(self, ["torch"])
-
-        if is_torch_tpu_available():
-            return xm.xrt_world_size()
+        if self.distributed_state is not None:
+            return self.distributed_state.num_processes
         elif is_sagemaker_mp_enabled():
             return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size()
-        elif is_sagemaker_dp_enabled():
-            return dist.get_world_size()
-        elif self.parallel_mode == ParallelMode.DISTRIBUTED:
-            return torch.distributed.get_world_size()
         return 1
 
     @property
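
With this hunk, `world_size` reads `distributed_state.num_processes` on every backend except SageMaker model parallelism. A quick check, assuming a script launched both plainly and via `torchrun` (the `output_dir` value is arbitrary):

from transformers import TrainingArguments

args = TrainingArguments(output_dir="out")
# Expected: 1 in a plain single-process run; N under `torchrun --nproc_per_node N`.
print(args.world_size)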
@@ -1832,14 +1827,10 @@ class TrainingArguments:
         The index of the current process used.
         """
         requires_backends(self, ["torch"])
-        if is_torch_tpu_available():
-            return xm.get_ordinal()
+        if self.distributed_state is not None:
+            return self.distributed_state.process_index
         elif is_sagemaker_mp_enabled():
             return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank()
-        elif is_sagemaker_dp_enabled():
-            return dist.get_rank()
-        elif self.parallel_mode == ParallelMode.DISTRIBUTED:
-            return torch.distributed.get_rank()
         return 0
 
     @property
@@ -1848,14 +1839,11 @@ class TrainingArguments:
         The index of the local process used.
         """
         requires_backends(self, ["torch"])
-        if is_torch_tpu_available():
-            return xm.get_local_ordinal()
+
+        if self.distributed_state is not None:
+            return self.distributed_state.local_process_index
         elif is_sagemaker_mp_enabled():
             return smp.local_rank()
-        elif is_sagemaker_dp_enabled():
-            return dist.get_rank()
-        elif self.parallel_mode == ParallelMode.DISTRIBUTED:
-            return self.local_rank
         return 0
 
     @property
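
The two hunks above preserve the distinction between the global rank (`process_index`) and the per-node rank (`local_process_index`), now both sourced from the same `PartialState`. A sketch of how the two are typically used, assuming a hypothetical 2-node x 4-GPU job where `process_index` runs 0..7 across the job and `local_process_index` runs 0..3 on each node:

from transformers import TrainingArguments

args = TrainingArguments(output_dir="out")
if args.process_index == 0:
    print("exactly one process in the whole job, e.g. for logging or saving")
if args.local_process_index == 0:
    print("one process per node, e.g. for node-local downloads or caching")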
@@ -1944,19 +1932,19 @@ class TrainingArguments:
         """
         if is_torch_available() and self.world_size > 1:
-            main_process_desc = "main process"
-            if local:
-                is_main_process = self.local_process_index == 0
-                main_process_desc = "main local process"
+            main_process_desc = "main local process" if local else "main process"
+            if self.distributed_state is not None:
+                is_main_process = (
+                    self.distributed_state.is_local_main_process if local else self.distributed_state.is_main_process
+                )
             elif is_sagemaker_mp_enabled():
                 is_main_process = smp.rank() == 0
-            else:
-                is_main_process = self.process_index == 0
 
             try:
                 if not is_main_process:
                     # tell all replicas to wait
                     logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")
+
                     if is_torch_tpu_available():
                         xm.rendezvous(desc)
                     else:
...
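
For reference, the context manager touched by the last hunk is typically used as in the `transformers` example scripts: non-main processes block while the main process runs the body once, then re-enter and read the cached result. A sketch, assuming `raw_dataset` and `tokenize_fn` are defined elsewhere:

from transformers import TrainingArguments

args = TrainingArguments(output_dir="out")
with args.main_process_first(desc="dataset map pre-processing"):
    tokenized = raw_dataset.map(tokenize_fn, batched=True)  # raw_dataset / tokenize_fn: hypothetical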