"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5b11c5dc779b7e42022d8353b1b1aa6fb9b758f3"
Commit 7f17bbf0 authored by Anthony Chen, committed by Facebook GitHub Bot

expose use_orig_params to d2go config

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/582

Expose use_orig_params for the FSDP constructor in the d2go config. Read more about it in the docstring of torch.distributed.fsdp.fully_sharded_data_parallel.

use_orig_params=False (the default) stores flattened parameters in FlatParameters, which saves memory by avoiding fragmentation. However, use_orig_params=True is essential for models that are partly frozen, because a FlatParameter only accepts a uniform requires_grad across all of the parameters it flattens.
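As an illustration (not part of this commit), a minimal PyTorch sketch of the partly-frozen case; the single-process gloo group, the toy model, and the layer choice are placeholders, and real training would run under a distributed launcher:

# Illustrative sketch only (not d2go code): why use_orig_params=True matters
# for a partly frozen model. Uses a throwaway single-process "gloo" group so
# FSDP can be constructed for demonstration; real training would use torchrun.
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))
# Freeze only the first layer -> requires_grad is no longer uniform across the model.
for p in model[0].parameters():
    p.requires_grad = False

# With the default use_orig_params=False, FSDP flattens parameters into FlatParameters,
# and each FlatParameter needs uniform requires_grad across everything it flattens,
# so mixing frozen and trainable parameters inside one wrapped unit fails.
# use_orig_params=True keeps the original parameters addressable, so per-parameter
# requires_grad (and per-parameter optimizer groups) keep working.
fsdp_model = FSDP(model, use_orig_params=True)
print(fsdp_model)

dist.destroy_process_group()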

Reviewed By: wat3rBro

Differential Revision: D46917757

fbshipit-source-id: 12ebe83e6de456e37d89eaf8b257f23925a6786d
parent 60b6995d
@@ -37,9 +37,11 @@ def add_fsdp_configs(_C: CN):
     # Configs for fully sharded data parallel (fsdp)
     # Check out https://pytorch.org/docs/stable/fsdp.html
     # and docstring of torch.distributed.fsdp.fully_sharded_data_parallel
+    # See docstring of CpuOffload and BackwardPrefetch in torch.distributed.fsdp.fully_sharded_data_parallel
     _C.FSDP.CPU_OFFLOAD = False
     _C.FSDP.BACKWARD_PREFETCH = True
+    _C.FSDP.USE_ORIG_PARAMS = False
     # Find autowrap policy at D2GO_WRAP_POLICY_REGISTRY, or use '' to disable autowrap
     _C.FSDP.AUTO_WRAP_POLICY = "never_wrap_policy"
     _C.FSDP.AUTO_WRAP_MIN_PARAMS = int(1e4)
@@ -176,6 +176,7 @@ def build_fsdp(
     state_dict_rank0_only: bool = True,
     ignored_modules: Optional[nn.Module] = None,
     forward_prefetch: bool = False,
+    use_orig_params: bool = False,
     device_id: Optional[int] = None,
 ):
     if sharding_algorithm == ShardingAlgorithm.SHARD_GRAD_OP:
@@ -227,6 +228,7 @@ def build_fsdp(
         "backward_prefetch": backward_prefetch,
         "ignored_modules": ignored_modules,
         "forward_prefetch": forward_prefetch,
+        "use_orig_params": use_orig_params,
         "device_id": torch.cuda.current_device() if not device_id else device_id,
     }
     # default to using use_local_state_dict if state_dict_type is None
@@ -304,6 +306,7 @@ class FSDPModelingHook(ModelingHook):
             state_dict_rank0_only=self.cfg.FSDP.STATE_DICT_RANK0_ONLY,
             ignored_modules=ignored_modules,
             forward_prefetch=forward_prefetch,
+            use_orig_params=self.cfg.FSDP.USE_ORIG_PARAMS,
             device_id=torch.cuda.current_device(),
         )
         return wrapped_model
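
For reference, a minimal sketch of how the new knob is meant to be used (illustrative, not part of this diff; assumes detectron2's CfgNode is available, since d2go configs are yacs nodes):

# Illustrative sketch: flip the new flag on a config node.
from detectron2.config import CfgNode as CN

cfg = CN()
cfg.FSDP = CN()
cfg.FSDP.USE_ORIG_PARAMS = True  # the default added by this change is False
# FSDPModelingHook forwards cfg.FSDP.USE_ORIG_PARAMS to build_fsdp(use_orig_params=...),
# which passes it straight through to torch.distributed.fsdp.FullyShardedDataParallel.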