Commit abdad994 authored by Geet Sethi, committed by Facebook GitHub Bot

distributed FSDP model initialization

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/656

Enable distributed FSDP model initialization. This iteratively moves and shards the model onto GPUs, allowing training of models that exceed a single GPU's HBM capacity and that cannot be instantiated multiple times on a single host.

The flow is as follows:
1. Rank 0 initializes the whole model on CPU using the existing code paths, while all other ranks initialize an 'empty' model using fake tensors.
2. Once this is complete and initialization moves to FSDP, distributed init traverses the model bottom-up, transferring all params/buffers from rank 0 to all other ranks while simultaneously wrapping modules in FSDP wherever possible (based on the specified config). Modules are thus sharded (and memory usage distributed) at the earliest opportunity, using the existing FSDP API/implementation. A minimal config sketch for enabling this path follows below.
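As a rough illustration, here is a minimal sketch of opting a config into this path. The field names (MODEL.MODELING_HOOKS, FSDP.DISTRIBUTED_INIT) and the FSDPModelingHook name come from this change; the helper function itself is hypothetical and not part of the diff.

def enable_distributed_fsdp_init(cfg):
    # The FSDP modeling hook must be active so the model is wrapped in FSDPWrapper.
    if "FSDPModelingHook" not in cfg.MODEL.MODELING_HOOKS:
        cfg.MODEL.MODELING_HOOKS = list(cfg.MODEL.MODELING_HOOKS) + ["FSDPModelingHook"]
    # New flag added by this change (defaults to False); build_d2go_model checks it
    # before taking the distributed initialization path.
    cfg.FSDP.DISTRIBUTED_INIT = True
    return cfg

Note that torch.distributed must already be initialized on every rank (e.g. by the launcher); otherwise build_d2go_model raises a RuntimeError.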

Reviewed By: XiaoliangDai

Differential Revision: D54287718

fbshipit-source-id: 16d63d78065d1fca0c6baf7a385f666a4e1b2a5f
parent 102305a5
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from dataclasses import dataclass
from typing import List
@@ -13,6 +14,8 @@ from d2go.registry.builtin import META_ARCH_REGISTRY
from d2go.utils.misc import _log_api_usage
from detectron2.modeling import META_ARCH_REGISTRY as D2_META_ARCH_REGISTRY

logger = logging.getLogger(__name__)

@dataclass
class D2GoModelBuildResult:
@@ -54,7 +57,35 @@ def build_meta_arch(cfg):
def build_d2go_model(
    cfg: CfgNode,
) -> D2GoModelBuildResult:
    # NOTE distributed initialization path (using FSDP) for large models
    if (
        hasattr(cfg.MODEL, "MODELING_HOOKS")
        and "FSDPModelingHook" in cfg.MODEL.MODELING_HOOKS
        and hasattr(cfg, "FSDP")
        and hasattr(cfg.FSDP, "DISTRIBUTED_INIT")
        and cfg.FSDP.DISTRIBUTED_INIT
    ):
        logger.info("Using distributed initialization path.")
        import torch.distributed as dist

        if dist.is_initialized():
            from d2go.trainer.fsdp import CpuOverrideMode
            from torch._subclasses import FakeTensorMode

            # NOTE (global) rank 0 will build the whole model on cpu;
            # other ranks will build the model on fake tensors
            if dist.get_rank() == 0:
                with CpuOverrideMode():
                    model = build_meta_arch(cfg)
            else:
                with FakeTensorMode(allow_non_fake_inputs=True):
                    model = build_meta_arch(cfg)
        else:
            raise RuntimeError(
                "torch.distributed is not initialized. Cannot proceed with distributed init."
            )
    else:
        model = build_meta_arch(cfg)
    modeling_hooks: List[mh.ModelingHook] = []
    # apply modeling hooks
    # some custom projects bypass d2go's default config so may not have the
...
@@ -3,7 +3,7 @@
import contextlib
import logging
from enum import Enum
from typing import Any, Dict, Generator, List, Optional, Set, Tuple

import torch
import torch.nn as nn
@@ -25,11 +25,21 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
    StateDictType,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode
logger = logging.getLogger(__name__)


class CpuOverrideMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args, kwargs=None):
        if kwargs is not None:
            if "device" in kwargs:
                kwargs["device"] = torch.device("cpu")
        return func(*args, **kwargs)
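
# Illustrative usage sketch (an assumption, not part of this change): under this
# dispatch mode, factory calls that pass an explicit `device` kwarg are redirected
# to CPU, e.g.
#   with CpuOverrideMode():
#       t = torch.ones(4, device="cuda")  # expected to be allocated on CPU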


def add_fsdp_configs(_C: CN):
    _C.FSDP = CN()
    _C.FSDP.ALGORITHM = "grad_optim"  # 'grad_optim', 'full', 'hybrid', 'hybrid_zero2'
@@ -60,6 +70,8 @@ def add_fsdp_configs(_C: CN):
    _C.FSDP.FORWARD_PREFETCH_OPTION = "no"
    # if False, this allows the CPU thread to schedule all-gathers without any extra synchronization
    _C.FSDP.LIMIT_ALL_GATHERS = False
    # flag for distributed FSDP model initialization
    _C.FSDP.DISTRIBUTED_INIT = False


class ShardingAlgorithm(str, Enum):
@@ -99,6 +111,86 @@ def get_grad_scaler(cfg):
    return ShardedGradScaler() if is_fsdp_enabled(cfg) else GradScaler()


def bottom_up_nested_fsdp(root_module, fsdp_kwargs: Dict[str, Any]):
    import torch.distributed as dist

    modules_to_fsdp: Tuple = tuple(fsdp_kwargs["auto_wrap_policy"]._module_classes)
    del fsdp_kwargs["auto_wrap_policy"]
    modules_not_to_fsdp: List = fsdp_kwargs["ignored_modules"]
    device_id = fsdp_kwargs["device_id"]
    cuda_device = torch.device(f"cuda:{device_id}")

    # postorder traversal (i.e. bottom-up)
    visited_modules: Set[nn.Module] = {root_module}

    def postorder_fsdp_wrap(
        module: nn.Module,
        module_name: str,
        fqn: str,
        parent_module: Optional[nn.Module],
        ignore_branch: bool,
    ):
        rank = dist.get_rank()
        # don't traverse branches of specified ignored modules
        if module in modules_not_to_fsdp:
            ignore_branch = True
        for child_name, child_module in module.named_children():
            if child_module not in visited_modules:
                visited_modules.add(child_module)
                postorder_fsdp_wrap(
                    child_module,
                    child_name,
                    f"{fqn}.{child_name}",
                    module,
                    ignore_branch,
                )
        logger.info(
            f"(Distributed FSDP init) Rank {rank} Beginning processing module: {fqn}"
        )
        # regardless of wrapping, we need to transfer all
        # module params and buffers to device, and if not rank 0,
        # need to retrieve data from rank 0
        with torch.no_grad():
            if rank != 0:
                with no_dispatch():
                    for name, param in module.named_parameters(recurse=False):
                        setattr(
                            module,
                            name,
                            torch.nn.Parameter(
                                torch.empty_like(param, device=cuda_device)
                            ),
                        )
                    for name, buffer in module.named_buffers(recurse=False):
                        setattr(
                            module, name, torch.empty_like(buffer, device=cuda_device)
                        )
            else:
                for _, param in module.named_parameters(recurse=False):
                    param.data = param.to(cuda_device)
                for _, buffer in module.named_buffers(recurse=False):
                    buffer.data = buffer.to(cuda_device)
            for _, param in module.named_parameters(recurse=False):
                dist.broadcast(param, 0)
            for _, buffer in module.named_buffers(recurse=False):
                dist.broadcast(buffer, 0)
        # if module is marked for FSDP, wrap it
        # AND if not in ignored branch
        if not ignore_branch and isinstance(module, modules_to_fsdp):
            logger.info(
                f"(Distributed FSDP init) Rank {rank} FSDP Wrapping module: {fqn}"
            )
            setattr(parent_module, module_name, FSDP(module, **fsdp_kwargs))
        logger.info(
            f"(Distributed FSDP init) Rank {rank} Finished processing module: {fqn}"
        )

    postorder_fsdp_wrap(root_module, "root", "root", None, False)


class FSDPWrapper(FSDP):
    def __init__(
        self,
@@ -108,6 +200,7 @@ class FSDPWrapper(FSDP):
        amp_autocast_dtype: Optional[torch.dtype] = None,
        state_dict_cpu_offload: bool = True,
        state_dict_rank0_only: bool = True,
        distributed_init: bool = False,
        **fsdp_kwargs,
    ):
        self.precision = amp_autocast_dtype
@@ -115,7 +208,14 @@ class FSDPWrapper(FSDP):
        self.load_state_dict_type = load_state_dict_type
        self.offload_to_cpu = state_dict_cpu_offload
        self.rank0_only = state_dict_rank0_only
        self.distributed_init = distributed_init

        if self.distributed_init:
            # NOTE traverse and apply all non-root level FSDP
            # and then wrap root level FSDP
            bottom_up_nested_fsdp(model, fsdp_kwargs)

        super().__init__(model, **fsdp_kwargs)
        logger.info(f"FSDP Wrapped model architecture: {self}")

    def forward(self, *args, **kwargs):
        # Wrap forward() in autocast if mixed precision is enabled
@@ -181,6 +281,7 @@ def build_fsdp(
    use_orig_params: bool = False,
    device_id: Optional[int] = None,
    limit_all_gathers: bool = False,
    distributed_init: bool = False,
):
    if sharding_algorithm == ShardingAlgorithm.SHARD_GRAD_OP:
        sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
@@ -256,6 +357,7 @@ def build_fsdp(
        "load_state_dict_type": _load_state_dict_type,
        "state_dict_cpu_offload": state_dict_cpu_offload,
        "state_dict_rank0_only": state_dict_rank0_only,
        "distributed_init": distributed_init,
    }
    return FSDPWrapper(model, **wrapper_kwargs, **fsdp_kwargs)
@@ -313,6 +415,7 @@ class FSDPModelingHook(ModelingHook):
            use_orig_params=self.cfg.FSDP.USE_ORIG_PARAMS,
            device_id=torch.cuda.current_device(),
            limit_all_gathers=self.cfg.FSDP.LIMIT_ALL_GATHERS,
            distributed_init=self.cfg.FSDP.DISTRIBUTED_INIT,
        )
        return wrapped_model
...