Unverified commit 105f6507, authored by foreveronehundred, committed by GitHub

Add a new arg, "force_broadcast_object", to OSS __init__ (#942)

* [FSDP] Add an arg for FSDP __init__

Add an arg, disable_reshard_on_root, to FSDP __init__ to address the following issue:
https://github.com/facebookresearch/fairscale/issues/878


In some cases (models wrapped by auto-wrap), the parameters of root modules need to be sharded, and reshard_after_forward should not be forced to False.
"disable_reshard_on_root" lets users choose whether reshard_after_forward of root modules is forced to False or not (see the usage sketch after this commit list).

* Update fully_sharded_data_parallel.py

Modified the description of the feature to explain it more clearly.

* Update fairscale/nn/data_parallel/fully_sharded_data_parallel.py

Update the comments for disable_reshard_on_root
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>

* Modified the comments

Modified the comments of disable_reshard_on_root

* Add a new argument for OSS __init__

Add a new argument to OSS __init__ that forces OSS to use "_broadcast_object" when rebuilding the sharded optimizer. For more details, see https://github.com/facebookresearch/fairscale/issues/937. A usage sketch follows the diff below.



* Remove redundant space

Remove redundant space
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>
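
A minimal sketch (not part of the patch) of exercising the FSDP argument described in the first commit above, assuming a fairscale build that ships disable_reshard_on_root. The single-process "gloo" group and the CPU model are only there to make the snippet self-contained; real training would normally use NCCL on GPUs.

```python
# Sketch only: constructing FSDP so that the root module keeps the
# user-provided reshard_after_forward setting instead of having it forced
# to False, per the behavior described in the commit message above.
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

if not dist.is_initialized():
    # Single-process process group for illustration purposes only.
    dist.init_process_group(
        backend="gloo", init_method="tcp://localhost:29500", rank=0, world_size=1
    )

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 10),
)

# disable_reshard_on_root=False keeps reshard_after_forward=True for the
# root module as well, which is the case the commit calls out for
# auto-wrapped models whose root parameters are themselves sharded.
sharded_model = FSDP(model, reshard_after_forward=True, disable_reshard_on_root=False)
```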
parent 40e7450f
@@ -107,7 +107,10 @@ class OSS(Optimizer):
             Compress the model shards in fp16 before sharing them in between ranks.
             This is safe to use when PyTorch AMP is activated. Without torch AMP this will lead to a slight
             degradation in terms of accuracy.
+        force_broadcast_object (bool):
+            If True, '_broadcast_object' will be used for rebuilding the sharded optimizer.
+            If False, whether to use '_broadcast_object' or 'dist.broadcast_object_list' will be determined by GPU capabilities.
+            This feature is needed since some newer GPUs still get some memory issues when applying dist.broadcast_object_list.

     .. warning: the communication patterns that OSS use depend on the "trainability" graph,
         meaning that all the parameters which `require_grad` are handled differently. This is
@@ -129,6 +132,7 @@ class OSS(Optimizer):
         group: Optional[Any] = None,
         broadcast_buffer_size: int = -1,
         broadcast_fp16: bool = False,
+        force_broadcast_object: bool = False,
         **default: Any,
     ):
@@ -156,6 +160,7 @@ class OSS(Optimizer):
         self._local_to_global_rank = [get_global_rank(self.group, i) for i in range(self.world_size)]
         self.broadcast_fp16 = broadcast_fp16
+        self.force_broadcast_object = force_broadcast_object
         self.buckets: Dict[torch.device, Dict[int, ParamBucket]] = {}
         self._all_states: List[Dict[str, Any]] = []  # Optional consolidated optimizer state
         self._default_device = torch.device("cpu")
@@ -334,7 +339,7 @@ class OSS(Optimizer):
                     if should_send_state
                     else torch.tensor([0], dtype=torch.uint8, device=dist_device)
                 )
-                if _gpu_capabilities_older_than_50():
+                if self.force_broadcast_object or _gpu_capabilities_older_than_50():
                     _broadcast_object(
                         state_to_share, src_rank=self.global_rank, group=self.group, dist_device=dist_device
                     )
@@ -347,7 +352,7 @@ class OSS(Optimizer):
                 )
             else:
                 # Fetch the optim state from the other replicas
-                if _gpu_capabilities_older_than_50():
+                if self.force_broadcast_object or _gpu_capabilities_older_than_50():
                     replica_state = _broadcast_object(
                         torch.tensor([0], dtype=torch.uint8, device=dist_device),
                         src_rank=self._local_to_global_rank[rank],
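
As referenced in the commit message, a minimal sketch of the new OSS argument in use, assuming the patched OSS shown in the diff above. The flag only changes which broadcast primitive is used internally when the sharded optimizer state is rebuilt; the public API is unchanged. The single-process "gloo" group is only there to make the snippet runnable on CPU.

```python
# Sketch only: opting into force_broadcast_object with the patched OSS.
import torch
import torch.distributed as dist
from fairscale.optim import OSS

if not dist.is_initialized():
    # Single-process process group for illustration purposes only.
    dist.init_process_group(
        backend="gloo", init_method="tcp://localhost:29501", rank=0, world_size=1
    )

model = torch.nn.Linear(16, 4)
optimizer = OSS(
    model.parameters(),
    optim=torch.optim.SGD,
    lr=0.01,
    # Always take the _broadcast_object path instead of
    # dist.broadcast_object_list, regardless of GPU capabilities.
    force_broadcast_object=True,
)

# The flag takes effect when the sharded state is consolidated, which is the
# code path modified in the diff above.
loss = model(torch.randn(8, 16)).sum()
loss.backward()
optimizer.step()
optimizer.consolidate_state_dict(recipient_rank=0)
if dist.get_rank() == 0:
    full_state = optimizer.state_dict()  # consolidated state on the recipient rank
```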