Unverified Commit ee67e7ad authored by Carolyn Wang's avatar Carolyn Wang Committed by GitHub
Browse files

patch for smddp import (#18244)

* add import

* format
parent 68097dcc
...@@ -1333,6 +1333,8 @@ class TrainingArguments: ...@@ -1333,6 +1333,8 @@ class TrainingArguments:
device = torch.device("cuda", local_rank) device = torch.device("cuda", local_rank)
self._n_gpu = 1 self._n_gpu = 1
elif is_sagemaker_dp_enabled(): elif is_sagemaker_dp_enabled():
import smdistributed.dataparallel.torch.torch_smddp # noqa: F401
dist.init_process_group(backend="smddp") dist.init_process_group(backend="smddp")
self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK")) self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
device = torch.device("cuda", self.local_rank) device = torch.device("cuda", self.local_rank)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment