"...api/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "fead3ba3867d09a2ac0e21a2e7395be5d70c02d1"
Commit 5f1ef548 authored by Fei Sun, committed by Facebook GitHub Bot

Prefetch forward

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/492

Enable prefetching of the FSDP all-gathers in the forward pass. Forward prefetch may or may not improve performance; its effectiveness depends on other FSDP options, such as zero2/zero3 and HSDP/FSDP. An HPO sweep is needed to find the best configuration.
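For illustration only, a minimal sketch (not the d2go wrapper itself) of how the underlying PyTorch FSDP constructor exposes forward prefetch; model construction and process-group initialization are assumed to happen elsewhere:

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

def wrap_with_forward_prefetch(model: nn.Module) -> FSDP:
    # forward_prefetch=True asks FSDP to issue the next layer's all-gather
    # before the current layer's forward finishes, overlapping communication
    # with computation.
    return FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # zero3-style sharding
        forward_prefetch=True,
        device_id=torch.cuda.current_device(),
    )

Whether this overlap actually helps depends on the sharding configuration, which is why the option is exposed in the config rather than hard-coded.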

Reviewed By: wat3rBro

Differential Revision: D43027253

fbshipit-source-id: cbf1b4bcf5b0b8301b5b9547e3c22b1f0ffc7590
parent 255313d8
@@ -61,6 +61,8 @@ def add_fsdp_configs(_C: CN):
     _C.FSDP.STATE_DICT_RANK0_ONLY = True
     # The ignored modules, if any
     _C.FSDP.IGNORED_MODULES = None
+    # Whether to prefetch in forward pass
+    _C.FSDP.FORWARD_PREFETCH_OPTION = "no"


 class ShardingAlgorithm(str, Enum):
@@ -79,6 +81,19 @@ class ShardingAlgorithm(str, Enum):
     HYBRID_SHARD_ZERO2 = "hybrid_zero2"
+
+
+class ForwardPrefetchOption(str, Enum):
+    """
+    This enum specifies the forward prefetch types to be used by FullyShardedDataParallel (FSDP).
+    "auto" => Use the default forward prefetch mechanism in FSDP.
+    "manual" => Use custom forward prefetch mechanism, implemented as training hooks.
+    "no" => No forward prefetch.
+    """
+
+    AUTO = "auto"
+    MANUAL = "manual"
+    NO = "no"


 def is_fsdp_enabled(cfg):
     return "FSDPModelingHook" in cfg.MODEL.MODELING_HOOKS
@@ -161,6 +176,7 @@ def build_fsdp(
     state_dict_cpu_offload: bool = True,
     state_dict_rank0_only: bool = True,
     ignored_modules: Optional[nn.Module] = None,
+    forward_prefetch: bool = False,
     device_id: Optional[int] = None,
 ):
     if sharding_algorithm == ShardingAlgorithm.SHARD_GRAD_OP:
@@ -211,6 +227,7 @@ def build_fsdp(
         "auto_wrap_policy": auto_wrap_policy,
         "backward_prefetch": backward_prefetch,
         "ignored_modules": ignored_modules,
+        "forward_prefetch": forward_prefetch,
         "device_id": torch.cuda.current_device() if not device_id else device_id,
     }
     wrapper_kwargs = {
@@ -244,6 +261,9 @@ class FSDPModelingHook(mh.ModelingHook):
                 assert mod is not None, f"Module {mod_name} cannot be found in model."
                 ignored_modules.append(mod)
+        forward_prefetch = (
+            self.cfg.FSDP.FORWARD_PREFETCH_OPTION == ForwardPrefetchOption.AUTO
+        )
         wrapped_model = build_fsdp(
             model,
             sharding_algorithm=self.cfg.FSDP.ALGORITHM,
@@ -263,6 +283,7 @@ class FSDPModelingHook(mh.ModelingHook):
             state_dict_cpu_offload=self.cfg.FSDP.STATE_DICT_CPU_OFFLOAD,
             state_dict_rank0_only=self.cfg.FSDP.STATE_DICT_RANK0_ONLY,
             ignored_modules=ignored_modules,
+            forward_prefetch=forward_prefetch,
             device_id=torch.cuda.current_device(),
         )
         return wrapped_model
......
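For reference, a hedged usage sketch based on the names added in this diff; the config-building helper (here called get_default_cfg) is hypothetical, while the FSDP.* keys and their values come from the code above:

cfg = get_default_cfg()  # hypothetical helper returning a CfgNode populated by add_fsdp_configs
cfg.MODEL.MODELING_HOOKS = ["FSDPModelingHook"]  # enable the FSDP modeling hook
cfg.FSDP.ALGORITHM = "hybrid_zero2"              # one of the ShardingAlgorithm values
cfg.FSDP.FORWARD_PREFETCH_OPTION = "auto"        # maps to FSDP's built-in forward_prefetch=True
# "manual" would instead rely on custom prefetch training hooks; "no" (the default) disables forward prefetch.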