Unverified commit 75bbf20b authored by Zachary Mueller, committed by GitHub

Fix sagemaker DP/MP (#23681)

* Check for use_sagemaker_dp

* Add a check for is_sagemaker_mp when setting _n_gpu again. Should be the last broken thing

* Try explicit check?

* Quality
parent 89159651
@@ -3398,7 +3398,9 @@ class Trainer:
             tensors = nested_xla_mesh_reduce(tensors, name)
         elif is_sagemaker_mp_enabled():
             tensors = smp_gather(tensors)
-        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+        elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or (
+            self.args.distributed_state is None and self.local_rank != -1
+        ):
             tensors = distributed_concat(tensors)
         return tensors
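
This hunk replaces the `parallel_mode == ParallelMode.DISTRIBUTED` check with an explicit look at `distributed_state`, so gathering still happens when Accelerate's state is unset but a legacy `local_rank` is. A minimal standalone sketch of that predicate (the class and function names below are illustrative stand-ins, not the Trainer API):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class _State:
    # Stand-in for accelerate's PartialState; only the field the check reads.
    distributed_type: str  # e.g. "NO", "MULTI_GPU", "SAGEMAKER_DP"


def should_distributed_concat(state: Optional[_State], local_rank: int) -> bool:
    # Mirrors the new elif: consult distributed_state when it exists,
    # otherwise fall back to the old local_rank != -1 test.
    if state is not None:
        return state.distributed_type != "NO"
    return local_rank != -1


assert should_distributed_concat(_State("MULTI_GPU"), local_rank=0)   # distributed run
assert not should_distributed_concat(_State("NO"), local_rank=-1)     # single process
assert should_distributed_concat(None, local_rank=0)                  # legacy launcher, no state
```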
@@ -1629,6 +1629,9 @@ class TrainingArguments:
             device = torch.device("cuda", local_rank)
             self._n_gpu = 1
             torch.cuda.set_device(device)
+        elif is_sagemaker_dp_enabled():
+            self.distributed_state = PartialState(_use_sagemaker_dp=True)
+            self._n_gpu = 1
         elif self.deepspeed:
             # Need to do similar for Accelerator init
             os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
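
The new `elif is_sagemaker_dp_enabled():` branch builds the `PartialState` with `_use_sagemaker_dp=True` before the DeepSpeed branch is reached. For context, a simplified sketch of how a SageMaker-DP check can read the launcher's environment (the real `is_sagemaker_dp_enabled` in `transformers.utils` additionally verifies that `smdistributed` is importable):

```python
import json
import os


def sagemaker_dp_requested() -> bool:
    # SageMaker's launcher serializes its settings into SM_FRAMEWORK_PARAMS.
    params_raw = os.environ.get("SM_FRAMEWORK_PARAMS", "{}")
    try:
        params = json.loads(params_raw)
    except json.JSONDecodeError:
        return False
    # True only when distributed data parallel was explicitly requested.
    return bool(params.get("sagemaker_distributed_dataparallel_enabled", False))
```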
@@ -1653,8 +1656,9 @@ class TrainingArguments:
         if is_torch_tpu_available():
             device = self.distributed_state.device
             self._n_gpu = 0
-        elif is_sagemaker_dp_enabled():
-            self._n_gpu = 1
+        elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
+            # Already set _n_gpu
+            pass
         elif self.distributed_state.distributed_type == DistributedType.NO:
             if self.use_mps_device:
                 if not torch.backends.mps.is_available():
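
The last hunk keeps `_n_gpu` from being overwritten: once the SageMaker DP branch above (or SageMaker MP) has set it, this branch simply passes through. A condensed, illustrative sketch of the resulting branch order (the boolean flags are hypothetical stand-ins for the real availability checks):

```python
def resolve_n_gpu(tpu: bool, sagemaker_dp: bool, sagemaker_mp: bool, current_n_gpu: int) -> int:
    # Illustrative only: mirrors which branch is allowed to change _n_gpu.
    if tpu:
        return 0
    if sagemaker_dp or sagemaker_mp:
        # Already set to 1 in the earlier SageMaker branch; leave it alone.
        return current_n_gpu
    # Other paths (DeepSpeed, plain CUDA, MPS, CPU) keep their own handling.
    return current_n_gpu


assert resolve_n_gpu(tpu=False, sagemaker_dp=True, sagemaker_mp=False, current_n_gpu=1) == 1
assert resolve_n_gpu(tpu=True, sagemaker_dp=False, sagemaker_mp=False, current_n_gpu=1) == 0
```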