Unverified Commit 22a0dd2e authored by Sourab Mangrulkar, committed by GitHub

fixing error when using sharded ddp (#18435)

parent 5096a654
@@ -1344,9 +1344,8 @@ class Trainer:
                     reshard_after_forward=zero_3,
                     cpu_offload=cpu_offload,
                 ).to(self.args.device)
-
         # Distributed training using PyTorch FSDP
-        if self.fsdp is not None:
+        elif self.fsdp is not None:
             # PyTorch FSDP!
             from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
             from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
@@ -1394,7 +1393,6 @@ class Trainer:
                 )
             if FSDPOption.OFFLOAD not in self.args.fsdp:
                 model.to(self.args.device)
-
         elif is_sagemaker_dp_enabled():
             model = nn.parallel.DistributedDataParallel(
                 model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
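
For context: in Trainer._wrap_model, the fairscale sharded-DDP branch comes just before the FSDP check, and the FSDP check heads the if/elif chain that ends in plain DistributedDataParallel. Because the FSDP check previously opened a fresh `if`, a model already wrapped by fairscale's sharded DDP fell through to the trailing `elif` of that second chain and was wrapped a second time in plain DDP, which raised the error reported against sharded ddp. The toy sketch below (hypothetical names, not the transformers API) reproduces just that control-flow bug and the one-token fix:

# Toy reproduction of the control-flow bug this commit fixes.
# Names are hypothetical; strings stand in for the wrapped model.

def wrap_buggy(model, sharded_ddp, fsdp, local_rank):
    if sharded_ddp:
        model = f"ShardedDDP({model})"
    if fsdp:                       # bug: starts a new chain, not exclusive
        model = f"FSDP({model})"
    elif local_rank != -1:
        model = f"DDP({model})"    # double-wraps a ShardedDDP model
    return model

def wrap_fixed(model, sharded_ddp, fsdp, local_rank):
    if sharded_ddp:
        model = f"ShardedDDP({model})"
    elif fsdp:                     # fix: one mutually exclusive chain
        model = f"FSDP({model})"
    elif local_rank != -1:
        model = f"DDP({model})"
    return model

print(wrap_buggy("model", sharded_ddp=True, fsdp=False, local_rank=0))
# -> DDP(ShardedDDP(model))   wrapped twice: the reported error
print(wrap_fixed("model", sharded_ddp=True, fsdp=False, local_rank=0))
# -> ShardedDDP(model)        wrapped once, as intended

Changing `if self.fsdp is not None:` to `elif` merges the two chains, so exactly one distributed-wrapping strategy ever applies to the model.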