Unverified Commit 73bf5964 authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[minor] help pure fp16 FSDP init a bit (#1068)

* [minor] [FSDP] add a better for pure fp16

* [minor] [wrap] add a flag to help fsdp pure fp16 wrapping
parent 454537d1
......@@ -475,6 +475,12 @@ class FullyShardedDataParallel(nn.Module):
self._num_flatten_params = len(self._fsdp_wrapped_module.flat_params)
self._param_name_groups = param_name_groups
# Check to see if the mixed precision setting is correct.
if self.compute_dtype is torch.float16 and self.mixed_precision is False:
for p in self.params:
if p.dtype is not torch.float16:
raise ValueError("Expecting FP16 param type in pure FP16 mode.")
# Shard module parameters in place
self._shard_parameters_()
......
......@@ -184,6 +184,8 @@ def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
assert isinstance(module_overrides, dict)
wrap_overrides = {**ConfigAutoWrap.kwargs, **module_overrides, **wrap_overrides}
assert ConfigAutoWrap.wrapper_cls is not None
if ConfigAutoWrap.move_module_cuda_half:
module = module.cuda().half()
return ConfigAutoWrap.wrapper_cls(module, **wrap_overrides)
return module
......@@ -236,6 +238,7 @@ class ConfigAutoWrap:
"""
in_autowrap_context: bool = False # Context flag
move_module_cuda_half: bool = False # A flag to control the wrap() function.
wrapper_cls: Optional[Callable] = None # The wrapper class
kwargs: Dict[str, Any] = {} # Wrapper's args
auto_wrap_policy: Optional[Callable] = None # Used only in auto_wrap
......@@ -252,6 +255,9 @@ class ConfigAutoWrap:
)
ConfigAutoWrap.in_autowrap_context = True
# Get and save the wrapper cls for the context.
if "move_module_cuda_half" in kwargs.keys():
ConfigAutoWrap.move_module_cuda_half = cast(bool, kwargs["move_module_cuda_half"])
del kwargs["move_module_cuda_half"]
assert "wrapper_cls" in kwargs.keys()
ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
del kwargs["wrapper_cls"]
......@@ -262,6 +268,7 @@ class ConfigAutoWrap:
@staticmethod
def disable_autowrap_context() -> None:
ConfigAutoWrap.in_autowrap_context = False
ConfigAutoWrap.move_module_cuda_half = False
ConfigAutoWrap.wrapper_cls = None
ConfigAutoWrap.kwargs = {}
ConfigAutoWrap.auto_wrap_policy = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment