Commit 5f1ef548 authored by Fei Sun, committed by Facebook GitHub Bot

Prefetch forward

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/492

Enable prefetching of the FSDP all-gathers in the forward pass. Forward prefetch may or may not improve performance: its effectiveness depends on the other FSDP options in use, such as zero2 vs. zero3 and HSDP vs. FSDP. An HPO sweep is needed to find the best configuration.
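
For illustration, a minimal sketch of how the new option would be enabled from a training config (hypothetical cfg object; the key and its values come from this diff):

    # Assumes a d2go-style YACS config, registered via add_fsdp_configs below.
    # "auto" delegates to FSDP's built-in forward prefetch; "manual" relies on
    # custom training hooks; "no" (the default) disables forward prefetch.
    cfg.FSDP.FORWARD_PREFETCH_OPTION = "auto"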

Reviewed By: wat3rBro

Differential Revision: D43027253

fbshipit-source-id: cbf1b4bcf5b0b8301b5b9547e3c22b1f0ffc7590
parent 255313d8
@@ -61,6 +61,8 @@ def add_fsdp_configs(_C: CN):
     _C.FSDP.STATE_DICT_RANK0_ONLY = True
     # The ignored modules, if any
     _C.FSDP.IGNORED_MODULES = None
+    # Whether to prefetch in the forward pass
+    _C.FSDP.FORWARD_PREFETCH_OPTION = "no"


 class ShardingAlgorithm(str, Enum):
@@ -79,6 +81,19 @@ class ShardingAlgorithm(str, Enum):
     HYBRID_SHARD_ZERO2 = "hybrid_zero2"


+class ForwardPrefetchOption(str, Enum):
+    """
+    This enum specifies the forward prefetch types to be used by FullyShardedDataParallel (FSDP).
+    "auto" => Use the default forward prefetch mechanism in FSDP.
+    "manual" => Use a custom forward prefetch mechanism, implemented as training hooks.
+    "no" => No forward prefetch.
+    """
+
+    AUTO = "auto"
+    MANUAL = "manual"
+    NO = "no"
+
+
 def is_fsdp_enabled(cfg):
     return "FSDPModelingHook" in cfg.MODEL.MODELING_HOOKS
@@ -161,6 +176,7 @@ def build_fsdp(
     state_dict_cpu_offload: bool = True,
     state_dict_rank0_only: bool = True,
     ignored_modules: Optional[nn.Module] = None,
+    forward_prefetch: bool = False,
     device_id: Optional[int] = None,
 ):
     if sharding_algorithm == ShardingAlgorithm.SHARD_GRAD_OP:
@@ -211,6 +227,7 @@
         "auto_wrap_policy": auto_wrap_policy,
         "backward_prefetch": backward_prefetch,
         "ignored_modules": ignored_modules,
+        "forward_prefetch": forward_prefetch,
         "device_id": torch.cuda.current_device() if not device_id else device_id,
     }
     wrapper_kwargs = {
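
The kwargs above are ultimately forwarded to PyTorch's FSDP wrapper. A minimal standalone sketch using the real torch.distributed.fsdp API (toy model; requires an initialized process group, e.g. via torchrun):

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # forward_prefetch=True asks FSDP to issue the next layer's all-gather
    # while the current layer's forward computation is still running.
    model = FSDP(nn.Linear(8, 8).cuda(), forward_prefetch=True)
    out = model(torch.randn(4, 8, device="cuda"))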
@@ -244,6 +261,9 @@ class FSDPModelingHook(mh.ModelingHook):
             assert mod is not None, f"Module {mod_name} cannot be found in model."
             ignored_modules.append(mod)
+        forward_prefetch = (
+            self.cfg.FSDP.FORWARD_PREFETCH_OPTION == ForwardPrefetchOption.AUTO
+        )
         wrapped_model = build_fsdp(
             model,
             sharding_algorithm=self.cfg.FSDP.ALGORITHM,
@@ -263,6 +283,7 @@ class FSDPModelingHook(mh.ModelingHook):
             state_dict_cpu_offload=self.cfg.FSDP.STATE_DICT_CPU_OFFLOAD,
             state_dict_rank0_only=self.cfg.FSDP.STATE_DICT_RANK0_ONLY,
             ignored_modules=ignored_modules,
+            forward_prefetch=forward_prefetch,
             device_id=torch.cuda.current_device(),
         )
         return wrapped_model