Commit abf0ca0c authored by Anthony Chen, committed by Facebook GitHub Bot

Make AMP compatible with FSDP

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/458

Make AMP compatible with FSDP. FSDP does not depend on the torch AMP module; instead it implements its own MixedPrecision module, which keeps an additional copy of the weights in lower precision and runs these low-precision tensors during mixed precision training. This is very different from AMP, which automatically casts tensors to lower precision at each tensor operation.
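
For context, a minimal sketch contrasting the two mechanisms (illustrative only; the dtypes and variable names below are not taken from this diff):

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# FSDP-native mixed precision: a static configuration handed to FSDP at wrap
# time. FSDP keeps a low-precision copy of the parameters and computes with it.
fp16_policy = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
)
# sharded = FullyShardedDataParallel(module, mixed_precision=fp16_policy)

# Torch AMP: nothing is stored in low precision; individual ops are cast on
# the fly only while the autocast context is active.
# with torch.cuda.amp.autocast(dtype=torch.float16):
#     out = module(inputs)
```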

This diff fixes compatibility bugs between AMP and FSDP with two changes:
1. Use "never_wrap_policy" as the default dummy autowrap policy.
FSDP mixed precision does not work with BatchNorm layers, because FSDP and other resources such as NVIDIA Apex strongly discourage running BatchNorm in lower precision: https://github.com/pytorch/pytorch/issues/75478. We therefore need some autowrap policy so that FSDP can skip BatchNorm layers when setting up mixed precision (a plain-FSDP sketch of this policy appears after the diff below).
2. Wrap FSDPWrapper.forward() with autocast().
FSDP mixed precision uses lower-precision tensors in computation, which can raise type mismatch errors when amp.autocast() is not enabled, e.g. during eval. We therefore wrap FSDPWrapper.forward() in autocast(); a standalone illustration of this failure mode follows this list.
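
To make the eval-time failure mode in change 2 concrete, here is a small self-contained illustration in plain PyTorch (no FSDP; bfloat16 on CPU is used only so the demo runs anywhere, while the diff targets CUDA fp16):

```python
import torch
import torch.nn as nn

# Stand-in for FSDP's low-precision parameter copy: a layer whose weights
# live in bfloat16.
layer = nn.Linear(4, 4).to(torch.bfloat16)
x = torch.randn(2, 4)  # fp32 input, as a dataloader would typically produce

try:
    layer(x)  # dtype mismatch: fp32 input vs bf16 weights
except RuntimeError as err:
    print("without autocast:", err)

# With autocast active, the op's inputs are cast to the autocast dtype and the
# mismatch disappears -- this is what wrapping forward() in autocast buys us.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = layer(x)
print("with autocast:", out.dtype)  # torch.bfloat16
```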

Reviewed By: wat3rBro

Differential Revision: D41328834

fbshipit-source-id: 18cf94c4ad8d9422ffd3bb335873cd29ac987ae9
parent 5ad2d57e
@@ -48,7 +48,7 @@ def add_fsdp_configs(_C: CN):
     _C.FSDP.CPU_OFFLOAD = False
     _C.FSDP.BACKWARD_PREFETCH = True
     # Find autowrap policy at D2GO_FSDP_WRAP_POLICY_REGISTRY, or use '' to disable autowrap
-    _C.FSDP.AUTO_WRAP_POLICY = ""
+    _C.FSDP.AUTO_WRAP_POLICY = "never_wrap_policy"
     _C.FSDP.AUTO_WRAP_MIN_PARAMS = int(1e4)
     # A list of layer cls names to wrap, case sensitive
     _C.FSDP.AUTO_WRAP_LAYER_CLS = []
@@ -84,18 +84,30 @@ class FSDPWrapper(FSDP):
     def __init__(
         self,
         model,
-        use_local_state_dict=False,
-        load_local_state_dict=False,
-        state_dict_cpu_offload=True,
-        state_dict_rank0_only=True,
+        amp_autocast_dtype: Optional[torch.dtype] = None,
+        use_local_state_dict: bool = False,
+        load_local_state_dict: bool = False,
+        state_dict_cpu_offload: bool = True,
+        state_dict_rank0_only: bool = True,
         **fsdp_kwargs,
     ):
+        self.precision = amp_autocast_dtype
         self.use_local_state_dict = use_local_state_dict
         self.load_local_state_dict = load_local_state_dict
         self.offload_to_cpu = state_dict_cpu_offload
         self.rank0_only = state_dict_rank0_only
         super().__init__(model, **fsdp_kwargs)

+    def forward(self, *args, **kwargs):
+        # Wrap forward() in autocast if mixed precision is enabled
+        if self.precision is not None and not torch.is_autocast_enabled():
+            from torch.cuda.amp import autocast
+
+            with autocast(dtype=self.precision):
+                return super().forward(*args, **kwargs)
+        else:
+            return super().forward(*args, **kwargs)
+
     @contextlib.contextmanager
     def state_dict_type_and_config(self, is_sharded: bool) -> Generator:
         if is_sharded:
@@ -136,6 +148,7 @@ def build_fsdp(
     param_dtype: Optional[torch.dtype] = None,
     reduce_dtype: Optional[torch.dtype] = None,
     buffer_dtype: Optional[torch.dtype] = None,
+    amp_autocast_dtype: Optional[torch.dtype] = None,
     use_local_state_dict: bool = False,
     load_local_state_dict: bool = False,
     state_dict_cpu_offload: bool = True,
@@ -181,6 +194,7 @@ def build_fsdp(
         "device_id": torch.cuda.current_device() if not device_id else device_id,
     }
     wrapper_kwargs = {
+        "amp_autocast_dtype": amp_autocast_dtype,
         "use_local_state_dict": use_local_state_dict,
         "load_local_state_dict": load_local_state_dict,
         "state_dict_cpu_offload": state_dict_cpu_offload,
@@ -214,6 +228,7 @@ class FSDPModelingHook(mh.ModelingHook):
            param_dtype=precision_dtype,
            reduce_dtype=precision_dtype,
            buffer_dtype=precision_dtype,
+           amp_autocast_dtype=precision_dtype,
            use_local_state_dict=self.cfg.FSDP.USE_LOCAL_STATE_DICT,
            load_local_state_dict=self.cfg.FSDP.USE_LOCAL_STATE_DICT,
            state_dict_cpu_offload=self.cfg.FSDP.STATE_DICT_CPU_OFFLOAD,
@@ -247,6 +262,18 @@ def get_module_class_from_name(module, name):
                return module_class


+@D2GO_FSDP_WRAP_POLICY_REGISTRY.register()
+def never_wrap_policy(model, **kwargs) -> Optional[Callable]:
+    """
+    Don't wrap any child module, only wrap the root
+    """
+
+    def never_wrap(*args, **kwargs):
+        return False
+
+    return never_wrap
+
+
 @D2GO_FSDP_WRAP_POLICY_REGISTRY.register()
 def always_wrap_policy(model, **kwargs) -> Optional[Callable]:
     """
...
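
As a supplementary note, a hedged sketch of how the registered never_wrap_policy maps onto plain FSDP, independent of the d2go registry plumbing above; the FSDP() call itself is left commented out because it requires an initialized process group and a CUDA device:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision


def never_wrap(*args, **kwargs) -> bool:
    # Callable auto-wrap policy: never wrap any child module, so only the root
    # becomes an FSDP unit. Per the summary above, supplying some policy is what
    # lets FSDP's mixed-precision setup handle BatchNorm submodules safely.
    return False


model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# fsdp_model = FSDP(
#     model,
#     auto_wrap_policy=never_wrap,
#     mixed_precision=MixedPrecision(
#         param_dtype=torch.float16,
#         reduce_dtype=torch.float16,
#         buffer_dtype=torch.float16,
#     ),
# )
```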