"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "5002b8e4a2cc8d5ac26939228c61e3ecf4c589e1"
Unverified Commit 8db64367 authored by Fanli Lin's avatar Fanli Lin Committed by GitHub
Browse files

Fix wrong xpu device in DistributedType.MULTI_XPU mode (#28386)

* remove elif xpu

* remove redudant code
parent 690fe73f
......@@ -1844,11 +1844,6 @@ class TrainingArguments:
device = torch.device("cuda", local_rank)
self._n_gpu = 1
torch.cuda.set_device(device)
elif is_torch_xpu_available() and "ACCELERATE_USE_XPU" not in os.environ:
os.environ["ACCELERATE_USE_XPU"] = "true"
self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
device = torch.device("xpu:0")
self._n_gpu = 1
elif is_sagemaker_dp_enabled():
self.distributed_state = PartialState(_use_sagemaker_dp=True)
self._n_gpu = 1
......@@ -1877,12 +1872,6 @@ class TrainingArguments:
elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
# Already set _n_gpu
pass
elif self.distributed_state.distributed_type == DistributedType.MULTI_XPU:
if "ACCELERATE_USE_XPU" not in os.environ:
os.environ["ACCELERATE_USE_XPU"] = "true"
self._n_gpu = 1
device = torch.device("xpu:0")
torch.xpu.set_device(device)
elif self.distributed_state.distributed_type == DistributedType.NO:
if self.use_mps_device:
warnings.warn(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment