Unverified Commit f768eb93 authored by girifb, committed by GitHub

Changing FSDP init to bypass pg validation (#619)



* Changing FSDP init to bypass process group (pg) validation for freshly minted pgs created inside of __init__.

* Addressing Min's review comments.
Co-authored-by: Giri Anantharaman <giriman@devfair0439.h2.fair>
parent b0e6b9bd
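For context, here is a minimal sketch of the two construction paths this change distinguishes (assuming `torch.distributed` is already initialized and using placeholder `Linear` modules): after this commit, `validate_process_group` only runs for a caller-supplied group, since a group freshly minted inside `__init__` does not need to be re-checked.

```python
import torch
import torch.distributed as dist
from fairscale.nn import FullyShardedDataParallel as FSDP

# Path 1: the caller supplies a process group -> FSDP still runs
# validate_process_group() against the compute device.
pg = dist.new_group()
fsdp_a = FSDP(torch.nn.Linear(8, 8).cuda(), process_group=pg)

# Path 2: no process group given -> __init__ mints one via dist.new_group()
# and, after this change, skips the now-redundant validation.
fsdp_b = FSDP(torch.nn.Linear(8, 8).cuda())
```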
@@ -7,6 +7,7 @@ import contextlib
 import copy
 from enum import Enum, auto
 import functools
+import logging
 from math import inf
 import time
 import traceback
@@ -187,6 +188,7 @@ class FullyShardedDataParallel(nn.Module):
         no_broadcast_optim_state: Optional[bool] = False,
         state_dict_device: Optional[torch.device] = None,
     ):
+        init_start = time.time()
         super().__init__()
         self.process_group = process_group or dist.new_group()
         self.rank = self.process_group.rank()
@@ -216,7 +218,10 @@ class FullyShardedDataParallel(nn.Module):
         if self.cpu_offload and not self.mixed_precision:
             raise ValueError("cpu_offload requires mixed_precision=True")
-        validate_process_group(self.compute_device, self.process_group)
+        # skip validation if the process group was created above
+        if process_group:
+            validate_process_group(self.compute_device, self.process_group)
         enable_pytorch_sync_bn(module)
         # Only handle params which are not already sharded. This enables
@@ -264,6 +269,11 @@ class FullyShardedDataParallel(nn.Module):
         # full params. This defaults to True, but may be set to False if the
         # user explicitly requests the local state dict via local_state_dict().
         self._return_full_state_dict = True
+        init_end = time.time()
+        logging.info(
+            f"FSDP.__init__(done): total_init_time: {(init_end - init_start): .4f} num_params: {(sum(p.numel() for p in self.params))}"
+        )

     def get_gradient_predivide_factor(self, world_size: int) -> int:
         factor = 1
@@ -1541,7 +1551,7 @@ class FullyShardedDataParallel(nn.Module):
         self._tstart = time.time()
         if self.rank == 0:
             gb_denom = 1024 ** 3
-            print(
+            logging.info(
                 f"{msg} cur={torch.cuda.memory_allocated()/gb_denom: .4f} GB, max={torch.cuda.max_memory_allocated()/gb_denom: .4f} GB, t={time.time()-self._tstart: .1f}"
             )
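Note that these messages now go through the standard `logging` module instead of `print`, so they are only emitted if the application configures logging at INFO level or below; a minimal sketch:

```python
import logging

# The root logger defaults to WARNING, which would swallow FSDP's
# logging.info(...) messages; raise verbosity to INFO to see them.
logging.basicConfig(level=logging.INFO)
```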
@@ -1627,7 +1637,7 @@ def _pre_load_state_dict_hook(
 ########################################################################################
-def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False) -> nn.Module:
+def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group: ProcessGroup = None) -> nn.Module:
     """
     Auto wrap all BatchNorm (BN) instances with a safer FSDP, esp. when convert
     to sync BN is used and the outer FSDP is flattening.
@@ -1654,13 +1664,17 @@ def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False) -> nn.Module:
         if recurse:
             return not isinstance(module, tuple(default_auto_wrap_policy.FORCE_LEAF_MODULES))  # type: ignore
         else:
-            return is_bn and not isinstance(module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES))  # type: ignore
+            return is_bn and not isinstance(
+                module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES)
+            )  # type: ignore
     pg = None
     if single_rank_pg:
         # No sharding with this single member group.
         my_rank = dist.get_rank()
         pg = dist.new_group(ranks=[my_rank])
+    else:
+        pg = process_group
     fsdp_config = {
         "wrapper_cls": FullyShardedDataParallel,
......
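A usage sketch for the extended `auto_wrap_bn` signature, assuming `torch.distributed` is already initialized; the small Sequential model is a placeholder, and the `fairscale.nn.data_parallel.fully_sharded_data_parallel` import path is assumed from fairscale's layout:

```python
import torch
import torch.distributed as dist
# Assumed module path; auto_wrap_bn is defined alongside FullyShardedDataParallel.
from fairscale.nn.data_parallel.fully_sharded_data_parallel import auto_wrap_bn

# Placeholder model containing a BatchNorm layer.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3),
    torch.nn.BatchNorm2d(16),
    torch.nn.ReLU(),
)

# Before this change, auto_wrap_bn either built a single-rank group
# (single_rank_pg=True) or left pg as None. With the new parameter, a
# caller-owned group can be passed in and reused by the BN wrappers.
shared_pg = dist.new_group()
model = auto_wrap_bn(model, single_rank_pg=False, process_group=shared_pg)
```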