Unverified Commit d9f36130 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[FSDP] Consolidate cpu_adam optimizer state dict (#607)

parent 1141528e
......@@ -810,7 +810,6 @@ def bench_mpi(args):
init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(local_rank % torch.cuda.device_count())
rpc.init_rpc(
f"Test{rank}",
......
......@@ -1416,6 +1416,9 @@ class FullyShardedDataParallel(nn.Module):
buffer = None # for sharded tensors
singleton_buffer = None # for singleton tensors
for buffer_name, t in v.items():
if torch.is_tensor(t):
t = t.to(self.compute_device)
if ou.is_singleton_tensor(t):
if singleton_buffer is None:
singleton_buffer = list(t.new_zeros(self.world_size).chunk(self.world_size))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment