Commit 7ef9d897 authored by Fei Sun, committed by Facebook GitHub Bot

Ignore modules

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/470

Enable ignoring modules in FSDP: the listed modules are not wrapped by FSDP. This is useful in the diffusion model, where the CLIP model is not trained, so it is fine to keep a separate full copy on each GPU. Ignoring CLIP reduces its execution time from 63ms to 48ms (a 15ms reduction), mostly because it is a CPU-bound module and FSDP injects extra code into each wrapped block. It also reduces the FSDP all-gather time before the CLIP execution from 56ms to 7ms (a 49ms reduction).

In total, this change reduces the CLIP runtime from 119ms to 55ms (a 64ms reduction, the sum of the execution-time and all-gather savings above).
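Under the hood, ignoring a module maps onto PyTorch FSDP's ignored_modules argument: an ignored module keeps an ordinary, unsharded copy on every rank, so FSDP neither injects per-block hooks into it nor all-gathers its parameters. A minimal sketch of that behavior, assuming a model with a frozen CLIP submodule stored at attribute "clip_model" (the attribute name is illustrative, not the d2go wiring):

    # Minimal sketch, not the d2go code path: shard everything except the
    # frozen CLIP submodule, which stays as a plain per-GPU copy.
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def wrap_except_clip(model: nn.Module) -> FSDP:
        return FSDP(model, ignored_modules=[model.clip_model])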

This feature is controlled by this flag:
    IGNORED_MODULES: ["clip_model"]
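A usage sketch under assumptions: add_fsdp_configs lives in the file changed below (module path assumed to be d2go.trainer.fsdp), and "clip_model" is an example fully-qualified submodule name that must exist in the model.

    # Sketch of enabling the feature through the config key added in this diff.
    from detectron2.config import CfgNode as CN
    from d2go.trainer.fsdp import add_fsdp_configs  # assumed module path

    cfg = CN()
    add_fsdp_configs(cfg)  # registers cfg.FSDP.* defaults, incl. IGNORED_MODULES = None
    cfg.FSDP.IGNORED_MODULES = ["clip_model"]  # FQNs of modules to keep out of FSDP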

Reviewed By: newstzpz

Differential Revision: D42910383

fbshipit-source-id: dc4c12254d45ac45d88329feb63a26ec4ae04aef
parent c4c512ce
@@ -13,6 +13,7 @@ from d2go.modeling import modeling_hook as mh
from d2go.registry.builtin import MODELING_HOOK_REGISTRY
from d2go.trainer.helper import parse_precision_from_string
from detectron2.utils.registry import Registry
from torch.ao.pruning import fqn_to_module
from torch.cuda.amp import GradScaler
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch,
@@ -58,6 +59,8 @@ def add_fsdp_configs(_C: CN):
    _C.FSDP.STATE_DICT_CPU_OFFLOAD = False
    # Whether to materialize state dict on rank 0
    _C.FSDP.STATE_DICT_RANK0_ONLY = True
    # Fully-qualified names of modules to exclude from FSDP wrapping, if any
    _C.FSDP.IGNORED_MODULES = None


class ShardingAlgorithm(str, Enum):
@@ -157,6 +160,7 @@ def build_fsdp(
    load_local_state_dict: bool = False,
    state_dict_cpu_offload: bool = True,
    state_dict_rank0_only: bool = True,
    ignored_modules: Optional[List[nn.Module]] = None,
    device_id: Optional[int] = None,
):
    if sharding_algorithm == ShardingAlgorithm.SHARD_GRAD_OP:
@@ -206,6 +210,7 @@
        "mixed_precision": mixed_precision,
        "auto_wrap_policy": auto_wrap_policy,
        "backward_prefetch": backward_prefetch,
        "ignored_modules": ignored_modules,
        "device_id": torch.cuda.current_device() if not device_id else device_id,
    }
    wrapper_kwargs = {
@@ -230,6 +235,15 @@ class FSDPModelingHook(mh.ModelingHook):
            if self.cfg.SOLVER.AMP.ENABLED
            else None
        )

        ignored_modules = None
        if isinstance(self.cfg.FSDP.IGNORED_MODULES, list):
            ignored_modules = []
            for mod_name in self.cfg.FSDP.IGNORED_MODULES:
                mod = fqn_to_module(model, mod_name)
                assert mod is not None, f"Module {mod_name} cannot be found in model."
                ignored_modules.append(mod)

        wrapped_model = build_fsdp(
            model,
            sharding_algorithm=self.cfg.FSDP.ALGORITHM,
@@ -248,6 +262,7 @@
            load_local_state_dict=self.cfg.FSDP.USE_LOCAL_STATE_DICT,
            state_dict_cpu_offload=self.cfg.FSDP.STATE_DICT_CPU_OFFLOAD,
            state_dict_rank0_only=self.cfg.FSDP.STATE_DICT_RANK0_ONLY,
            ignored_modules=ignored_modules,
            device_id=torch.cuda.current_device(),
        )
        return wrapped_model
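For reference, fqn_to_module (imported above from torch.ao.pruning) resolves a dotted fully-qualified name against the model and returns None when the path does not exist, which is what the assert checks. A rough equivalent using nn.Module.get_submodule, shown for illustration only:

    from typing import Optional

    import torch.nn as nn

    def resolve_fqn(model: nn.Module, fqn: str) -> Optional[nn.Module]:
        # Walk a dotted path such as "backbone.clip_model";
        # an empty FQN refers to the model itself.
        try:
            return model.get_submodule(fqn) if fqn else model
        except AttributeError:
            return None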