Unverified commit f768eb93, authored by girifb, committed by GitHub

Changing FSDP init to bypass pg validation (#619)



* Changing FSDP init to bypass pg validation for freshly minted pgs inside of init.

* Addressing Min's review comments.
Co-authored-by: Giri Anantharaman <giriman@devfair0439.h2.fair>
parent b0e6b9bd
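Note (not part of the commit): the change distinguishes two construction paths. A process group handed in by the caller is still validated, while a group freshly minted by dist.new_group() inside __init__ now skips the check. The sketch below is illustrative only; the single-process gloo bootstrap and the fairscale.nn.data_parallel import path are assumptions and may need adjusting per fairscale version.

    import os

    import torch.distributed as dist
    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP  # assumed import path

    # Single-process "gloo" group, only so the sketch can run standalone.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # Path 1: caller supplies a process group -> validate_process_group() still runs.
    caller_pg = dist.new_group(ranks=[0])
    fsdp_with_pg = FSDP(nn.Linear(8, 8), process_group=caller_pg)

    # Path 2: no group supplied -> __init__ mints its own via dist.new_group()
    # and, after this change, skips the redundant validation of that fresh group.
    fsdp_default_pg = FSDP(nn.Linear(8, 8))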
@@ -7,6 +7,7 @@ import contextlib
 import copy
 from enum import Enum, auto
 import functools
+import logging
 from math import inf
 import time
 import traceback
@@ -187,6 +188,7 @@ class FullyShardedDataParallel(nn.Module):
         no_broadcast_optim_state: Optional[bool] = False,
         state_dict_device: Optional[torch.device] = None,
     ):
+        init_start = time.time()
         super().__init__()
         self.process_group = process_group or dist.new_group()
         self.rank = self.process_group.rank()
@@ -216,7 +218,10 @@ class FullyShardedDataParallel(nn.Module):
         if self.cpu_offload and not self.mixed_precision:
             raise ValueError("cpu_offload requires mixed_precision=True")
-        validate_process_group(self.compute_device, self.process_group)
+        # skip validation if the process group was created above
+        if process_group:
+            validate_process_group(self.compute_device, self.process_group)
         enable_pytorch_sync_bn(module)
         # Only handle params which are not already sharded. This enables
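The guard above is the core of the change: only a caller-supplied group is validated, while a group created one line earlier by dist.new_group() is trusted by construction. A standalone sketch of the same pattern, with a hypothetical stand-in for validate_process_group() (which, per its call site, checks that the group is usable from the compute device):

    import torch.distributed as dist

    def _sanity_check(pg: dist.ProcessGroup) -> None:
        # Hypothetical stand-in for fairscale's validate_process_group().
        assert dist.get_world_size(group=pg) >= 1

    def resolve_group(process_group: dist.ProcessGroup = None) -> dist.ProcessGroup:
        pg = process_group or dist.new_group()
        # Mirrors the diff: only a caller-supplied group is validated; the group
        # created on the line above is assumed consistent and skipped.
        if process_group:
            _sanity_check(pg)
        return pg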
@@ -264,6 +269,11 @@ class FullyShardedDataParallel(nn.Module):
         # full params. This defaults to True, but may be set to False if the
         # user explicitly requests the local state dict via local_state_dict().
         self._return_full_state_dict = True
+        init_end = time.time()
+        logging.info(
+            f"FSDP.__init__(done): total_init_time: {(init_end - init_start): .4f} num_params: {(sum(p.numel() for p in self.params))}"
+        )
 
     def get_gradient_predivide_factor(self, world_size: int) -> int:
         factor = 1
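The new timing line goes through the root logger at INFO level, so it stays silent under Python's default WARNING threshold. Usage note (not part of the commit): configure logging before constructing the model to see it.

    import logging

    # Raise the root logger to INFO before building the FSDP instance; the
    # message added above then appears roughly as:
    #   INFO:root:FSDP.__init__(done): total_init_time:  0.0123 num_params: 1234567
    logging.basicConfig(level=logging.INFO)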
@@ -1541,7 +1551,7 @@ class FullyShardedDataParallel(nn.Module):
             self._tstart = time.time()
         if self.rank == 0:
             gb_denom = 1024 ** 3
-            print(
+            logging.info(
                 f"{msg} cur={torch.cuda.memory_allocated()/gb_denom: .4f} GB, max={torch.cuda.max_memory_allocated()/gb_denom: .4f} GB, t={time.time()-self._tstart: .1f}"
             )
@@ -1627,7 +1637,7 @@ def _pre_load_state_dict_hook(
 ########################################################################################
-def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False) -> nn.Module:
+def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group: ProcessGroup = None) -> nn.Module:
     """
     Auto wrap all BatchNorm (BN) instances with a safer FSDP, esp. when convert
     to sync BN is used and the outer FSDP is flattening.
@@ -1654,13 +1664,17 @@ def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False) -> nn.Module:
         if recurse:
             return not isinstance(module, tuple(default_auto_wrap_policy.FORCE_LEAF_MODULES))  # type: ignore
         else:
-            return is_bn and not isinstance(module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES))  # type: ignore
+            return is_bn and not isinstance(
+                module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES)
+            )  # type: ignore
 
     pg = None
     if single_rank_pg:
         # No sharding with this single member group.
         my_rank = dist.get_rank()
         pg = dist.new_group(ranks=[my_rank])
+    else:
+        pg = process_group
 
     fsdp_config = {
         "wrapper_cls": FullyShardedDataParallel,
...
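A usage sketch for the extended auto_wrap_bn(): single_rank_pg=True keeps the old behaviour (each rank wraps its BN layers in a one-rank group, so no sharding), while the new process_group argument lets the BN wrappers reuse an existing group when single_rank_pg is False; previously pg stayed None and each wrapper minted its own group in __init__. The import path and the pre-initialised backend are assumptions, not part of the diff.

    import torch.distributed as dist
    import torch.nn as nn
    from fairscale.nn.data_parallel import auto_wrap_bn  # assumed import path

    # Assumes dist.init_process_group(...) has already been called.
    model_a = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    model_b = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

    # Old behaviour, unchanged: BN layers get a fresh single-rank group (no sharding).
    wrapped_a = auto_wrap_bn(model_a, single_rank_pg=True)

    # New in this commit: hand the BN wrappers an existing group to share.
    shared_pg = dist.new_group()
    wrapped_b = auto_wrap_bn(model_b, single_rank_pg=False, process_group=shared_pg)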