Unverified Commit ee67e7ad authored by Carolyn Wang's avatar Carolyn Wang Committed by GitHub
Browse files

patch for smddp import (#18244)

* add import

* format
parent 68097dcc
......@@ -1333,6 +1333,8 @@ class TrainingArguments:
device = torch.device("cuda", local_rank)
self._n_gpu = 1
elif is_sagemaker_dp_enabled():
import smdistributed.dataparallel.torch.torch_smddp # noqa: F401
dist.init_process_group(backend="smddp")
self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
device = torch.device("cuda", self.local_rank)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment