Unverified Commit a8aad0ec authored by Zachary Mueller's avatar Zachary Mueller Committed by GitHub
Browse files

Fixup multigpu local_rank (#22869)

Fixup multigpu tests
parent 06bab003
...@@ -1537,9 +1537,7 @@ class TrainingArguments: ...@@ -1537,9 +1537,7 @@ class TrainingArguments:
) )
if self.no_cuda: if self.no_cuda:
self.distributed_state = PartialState(cpu=True) self.distributed_state = PartialState(cpu=True)
device = self.distributed_state.device
self._n_gpu = 0 self._n_gpu = 0
self.local_rank = self.distributed_state.local_process_index
elif is_sagemaker_mp_enabled(): elif is_sagemaker_mp_enabled():
local_rank = smp.local_rank() local_rank = smp.local_rank()
device = torch.device("cuda", local_rank) device = torch.device("cuda", local_rank)
...@@ -1548,11 +1546,12 @@ class TrainingArguments: ...@@ -1548,11 +1546,12 @@ class TrainingArguments:
elif self.deepspeed: elif self.deepspeed:
self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout)) self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
self._n_gpu = 1 self._n_gpu = 1
device = self.distributed_state.device
else: else:
self.distributed_state = PartialState(backend=self.xpu_backend) self.distributed_state = PartialState(backend=self.xpu_backend)
device = self.distributed_state.device
self._n_gpu = 1 self._n_gpu = 1
if not is_sagemaker_mp_enabled():
device = self.distributed_state.device
self.local_rank = self.distributed_state.local_process_index
if ( if (
torch.distributed.is_available() torch.distributed.is_available()
and torch.distributed.is_initialized() and torch.distributed.is_initialized()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment