Commit d8734049 authored by Xiaoliang Dai, committed by Facebook GitHub Bot

Allow setting limit_all_gathers in FSDP

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/598

Allow setting `limit_all_gathers` in FSDP. This enables faster training, as discussed in S351092.

Reviewed By: Sekunde

Differential Revision: D47603555

fbshipit-source-id: 48d672fd5cce1763da91d8b801a8cb81630bfcdc
parent 361c5457
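
For context, `limit_all_gathers` is the corresponding keyword argument on PyTorch's `FullyShardedDataParallel` constructor, which this change surfaces through the d2go config. A minimal sketch of what the wrapping amounts to (illustrative only, not part of the patch; it assumes a process group is already initialized and `model` lives on the current GPU):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def wrap_with_fsdp(model: torch.nn.Module, limit_all_gathers: bool = False):
    # Illustrative sketch only. When limit_all_gathers is False (the default
    # added in this diff), the CPU thread may schedule all-gathers without any
    # extra synchronization; when True, FSDP rate-limits in-flight all-gathers,
    # which bounds peak GPU memory.
    return FSDP(
        model,
        device_id=torch.cuda.current_device(),
        limit_all_gathers=limit_all_gathers,
    )
```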
@@ -58,6 +58,8 @@ def add_fsdp_configs(_C: CN):
     _C.FSDP.IGNORED_MODULES = None
     # Whether to prefetch in forward pass
     _C.FSDP.FORWARD_PREFETCH_OPTION = "no"
+    # if False, this allows the CPU thread to schedule all-gathers without any extra synchronization
+    _C.FSDP.LIMIT_ALL_GATHERS = False


 class ShardingAlgorithm(str, Enum):
@@ -178,6 +180,7 @@ def build_fsdp(
     forward_prefetch: bool = False,
     use_orig_params: bool = False,
     device_id: Optional[int] = None,
+    limit_all_gathers: bool = False,
 ):
     if sharding_algorithm == ShardingAlgorithm.SHARD_GRAD_OP:
         sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
@@ -230,6 +233,7 @@ def build_fsdp(
         "forward_prefetch": forward_prefetch,
         "use_orig_params": use_orig_params,
         "device_id": torch.cuda.current_device() if not device_id else device_id,
+        "limit_all_gathers": limit_all_gathers,
     }
     # default to using use_local_state_dict if state_dict_type is None
     if not state_dict_type:
@@ -308,6 +312,7 @@ class FSDPModelingHook(ModelingHook):
             forward_prefetch=forward_prefetch,
             use_orig_params=self.cfg.FSDP.USE_ORIG_PARAMS,
             device_id=torch.cuda.current_device(),
+            limit_all_gathers=self.cfg.FSDP.LIMIT_ALL_GATHERS,
         )
         return wrapped_model
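
With this change in place, a hypothetical way to flip the option on in user code is shown below. The key names follow the config added above; in practice the node would come from d2go's default config via `add_fsdp_configs`, not built by hand:

```python
from yacs.config import CfgNode as CN

# Hypothetical, hand-built stand-in for the d2go config node; only the FSDP
# block relevant to this diff is shown.
cfg = CN()
cfg.FSDP = CN()
cfg.FSDP.LIMIT_ALL_GATHERS = True  # forwarded to FSDP's limit_all_gathers
```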