Unverified Commit 105f6507 authored by foreveronehundred's avatar foreveronehundred Committed by GitHub
Browse files

Add a new arg, "force_broadcast_object", to OSS __init__ (#942)

* [FSDP] Add an arg for FSDP __init__

Add an arg, disable_reshard_on_root, for FSDP __init__ to handle the following issue
https://github.com/facebookresearch/fairscale/issues/878


For some cases (models wrapped by autowrap), the parameters (of root modules) needs to be sharded, and reshard_after_forward should not be set to False.
"disable_reshard_on_root" is for users to choose whether to force reshard_after_forward of root modules to be False or not.

* Update fully_sharded_data_parallel.py

Modified the description of the feature to explain more clear.

* Update fairscale/nn/data_parallel/fully_sharded_data_parallel.py

Update the comments for disable_reshard_on_root
Co-authored-by: default avatarMin Xu <24926999+min-xu-ai@users.noreply.github.com>

* Modified the comments

Modified the comments of disable_reshard_on_root

* Add a new argument for OSS __init__

Add a new argument for OSS __init__ to force the OSS to apply "_broadcast_object" for rebuilding the sharded optimizer. For more details, please see https://github.com/facebookresearch/fairscale/issues/937



* Remove redundant space

Remove redundant space
Co-authored-by: default avatarMin Xu <24926999+min-xu-ai@users.noreply.github.com>
parent 40e7450f
......@@ -107,7 +107,10 @@ class OSS(Optimizer):
Compress the model shards in fp16 before sharing them in between ranks.
This is safe to use when PyTorch AMP is activated. Without torch AMP this will lead to a slight
degradation in terms of accuracy.
force_broadcast_object (bool):
If True, '_broadcast_object' will be used for rebuilding the sharded optimizer.
If False, whether to use '_broadcast_object' or 'dist.broadcast_object_list' will be determined by GPU capabilities.
This feature is needed since some newer GPUs still get some memory issues when applying dist.broadcast_object_list.
.. warning: the communication patterns that OSS use depend on the "trainability" graph,
meaning that all the parameters which `require_grad` are handled differently. This is
......@@ -129,6 +132,7 @@ class OSS(Optimizer):
group: Optional[Any] = None,
broadcast_buffer_size: int = -1,
broadcast_fp16: bool = False,
force_broadcast_object: bool = False,
**default: Any,
):
......@@ -156,6 +160,7 @@ class OSS(Optimizer):
self._local_to_global_rank = [get_global_rank(self.group, i) for i in range(self.world_size)]
self.broadcast_fp16 = broadcast_fp16
self.force_broadcast_object = force_broadcast_object
self.buckets: Dict[torch.device, Dict[int, ParamBucket]] = {}
self._all_states: List[Dict[str, Any]] = [] # Optional consolidated optimizer state
self._default_device = torch.device("cpu")
......@@ -334,7 +339,7 @@ class OSS(Optimizer):
if should_send_state
else torch.tensor([0], dtype=torch.uint8, device=dist_device)
)
if _gpu_capabilities_older_than_50():
if self.force_broadcast_object or _gpu_capabilities_older_than_50():
_broadcast_object(
state_to_share, src_rank=self.global_rank, group=self.group, dist_device=dist_device
)
......@@ -347,7 +352,7 @@ class OSS(Optimizer):
)
else:
# Fetch the optim state from the other replicas
if _gpu_capabilities_older_than_50():
if self.force_broadcast_object or _gpu_capabilities_older_than_50():
replica_state = _broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=dist_device),
src_rank=self._local_to_global_rank[rank],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment