Commit 180c9197 authored by Anjali Sridhar
Browse files

simplify condition for readability

parent b09ddb2d
......@@ -587,19 +587,20 @@ class FullyShardedDataParallel(nn.Module):
assert p.dtype == torch.float32
# If world_size is 1, then we all-reduce grads instead of sharding.
p._is_sharded = self.world_size > 1
p._orig_size = p.data.size()
if not p._is_sharded:
if self.world_size == 1:
p._is_sharded = False
self.numel_padded_per_param.append(0)
continue
p._is_sharded = True
# Replace p.data with the relevant shard.
orig_data = p.data
p.data, num_padded = self._get_shard(p.data)
self.numel_padded_per_param.append(num_padded)
free_storage_(orig_data)
p._is_sharded = True
assert len(self.numel_padded_per_param) == len(self.params)
def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment