Commit 8acbec71 authored by Anjali Sridhar's avatar Anjali Sridhar
Browse files

revert accidental commit

parent 180c9197
......@@ -587,12 +587,14 @@ class FullyShardedDataParallel(nn.Module):
assert p.dtype == torch.float32
# If world_size is 1, then we all-reduce grads instead of sharding.
p._is_sharded = self.world_size > 1
p._orig_size = p.data.size()
if self.world_size == 1:
if not p._is_sharded:
p._is_sharded = False
self.numel_padded_per_param.append(0)
continue
p._is_sharded = True
# Replace p.data with the relevant shard.
orig_data = p.data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment