Unverified Commit 73bf5964 authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[minor] help pure fp16 FSDP init a bit (#1068)

* [minor] [FSDP] add a better check for pure fp16

* [minor] [wrap] add a flag to help fsdp pure fp16 wrapping
parent 454537d1
...@@ -475,6 +475,12 @@ class FullyShardedDataParallel(nn.Module): ...@@ -475,6 +475,12 @@ class FullyShardedDataParallel(nn.Module):
self._num_flatten_params = len(self._fsdp_wrapped_module.flat_params) self._num_flatten_params = len(self._fsdp_wrapped_module.flat_params)
self._param_name_groups = param_name_groups self._param_name_groups = param_name_groups
# Check to see if the mixed precision setting is correct.
if self.compute_dtype is torch.float16 and self.mixed_precision is False:
for p in self.params:
if p.dtype is not torch.float16:
raise ValueError("Expecting FP16 param type in pure FP16 mode.")
# Shard module parameters in place # Shard module parameters in place
self._shard_parameters_() self._shard_parameters_()
......
...@@ -184,6 +184,8 @@ def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module: ...@@ -184,6 +184,8 @@ def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
assert isinstance(module_overrides, dict) assert isinstance(module_overrides, dict)
wrap_overrides = {**ConfigAutoWrap.kwargs, **module_overrides, **wrap_overrides} wrap_overrides = {**ConfigAutoWrap.kwargs, **module_overrides, **wrap_overrides}
assert ConfigAutoWrap.wrapper_cls is not None assert ConfigAutoWrap.wrapper_cls is not None
if ConfigAutoWrap.move_module_cuda_half:
module = module.cuda().half()
return ConfigAutoWrap.wrapper_cls(module, **wrap_overrides) return ConfigAutoWrap.wrapper_cls(module, **wrap_overrides)
return module return module
...@@ -236,6 +238,7 @@ class ConfigAutoWrap: ...@@ -236,6 +238,7 @@ class ConfigAutoWrap:
""" """
in_autowrap_context: bool = False # Context flag in_autowrap_context: bool = False # Context flag
move_module_cuda_half: bool = False # A flag to control the wrap() function.
wrapper_cls: Optional[Callable] = None # The wrapper class wrapper_cls: Optional[Callable] = None # The wrapper class
kwargs: Dict[str, Any] = {} # Wrapper's args kwargs: Dict[str, Any] = {} # Wrapper's args
auto_wrap_policy: Optional[Callable] = None # Used only in auto_wrap auto_wrap_policy: Optional[Callable] = None # Used only in auto_wrap
...@@ -252,6 +255,9 @@ class ConfigAutoWrap: ...@@ -252,6 +255,9 @@ class ConfigAutoWrap:
) )
ConfigAutoWrap.in_autowrap_context = True ConfigAutoWrap.in_autowrap_context = True
# Get and save the wrapper cls for the context. # Get and save the wrapper cls for the context.
if "move_module_cuda_half" in kwargs.keys():
ConfigAutoWrap.move_module_cuda_half = cast(bool, kwargs["move_module_cuda_half"])
del kwargs["move_module_cuda_half"]
assert "wrapper_cls" in kwargs.keys() assert "wrapper_cls" in kwargs.keys()
ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"]) ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
del kwargs["wrapper_cls"] del kwargs["wrapper_cls"]
...@@ -262,6 +268,7 @@ class ConfigAutoWrap: ...@@ -262,6 +268,7 @@ class ConfigAutoWrap:
@staticmethod @staticmethod
def disable_autowrap_context() -> None: def disable_autowrap_context() -> None:
ConfigAutoWrap.in_autowrap_context = False ConfigAutoWrap.in_autowrap_context = False
ConfigAutoWrap.move_module_cuda_half = False
ConfigAutoWrap.wrapper_cls = None ConfigAutoWrap.wrapper_cls = None
ConfigAutoWrap.kwargs = {} ConfigAutoWrap.kwargs = {}
ConfigAutoWrap.auto_wrap_policy = None ConfigAutoWrap.auto_wrap_policy = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment