Commit fbc1c2e8 authored by David Yan, committed by Facebook GitHub Bot

Add support for FSDP SHARDED_STATE_DICT in D2Go

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/512

Currently, when saving and loading checkpoints for FSDP-wrapped modules, we use `StateDictType.LOCAL_STATE_DICT`, where the state_dict becomes essentially a single flat tensor under the `_flat_param` key (or some other layer-specific key for flat weights). This means that:
1. It's impossible to load weights directly from checkpoints, for example in notebooks (see the sketch after this list)
2. Converting from a local to a global checkpoint requires running a special workflow (https://fburl.com/code/6yqa4ldb) that occupies the same number of GPUs as were used during training
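As an illustration of problem 1, here is a minimal sketch (not part of this diff) of inspecting a per-rank checkpoint written under `LOCAL_STATE_DICT`; the `rank0.pth` file name follows the `rank{N}.pth` convention used by `FSDPCheckpointer`, and the exact keys will vary with the wrapping layout:

```python
import torch

# Hypothetical inspection of one rank's LOCAL_STATE_DICT checkpoint.
ckpt = torch.load("rank0.pth", map_location="cpu")
print(list(ckpt["model"].keys()))
# Prints flattened-parameter keys (e.g. ones ending in "_flat_param") rather
# than per-layer weight names, so these weights cannot be loaded into a plain
# nn.Module without first reassembling them across all training ranks.
```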

This diff adds an option, `FSDP.STATE_DICT_TYPE`, which selects the type of state dict to save (local, sharded, or full). In sharded mode with AIF checkpointing, state dicts can be loaded locally in minutes, with any number of GPUs, in notebooks and elsewhere.
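As a rough sketch of what this enables (assuming the config keys added in this diff and PyTorch's public FSDP API; `cfg` and `fsdp_model` are assumed to exist, standing in for the D2Go `CfgNode` and the model returned by `build_fsdp`):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    ShardedStateDictConfig,
    StateDictType,
)

# Config side: pick the new mode by name; FSDPModelingHook maps the string to
# a StateDictType member via StateDictType[...].
cfg.FSDP.STATE_DICT_TYPE = "SHARDED_STATE_DICT"

# What the wrapper does under the hood when state_dict() is called: enter
# FSDP's state_dict_type context so every rank emits sharded tensors instead
# of one flattened parameter per FSDP unit.
with FSDP.state_dict_type(
    fsdp_model,
    StateDictType.SHARDED_STATE_DICT,
    ShardedStateDictConfig(offload_to_cpu=True),
):
    sharded_state_dict = fsdp_model.state_dict()
```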

Note: for backwards compatibility, `CFG.FSDP.use_local_state_dict` and `CFG.FSDP.load_local_state_dict` still need to work when the new config parameter (`CFG.FSDP.state_dict_type`) is not set. These legacy flags are also used to signify that local/sharded state dicts need to be converted to a full state dict when loading. This functionality can be deprecated once everyone migrates to AIF checkpointing with sharded state dicts.
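A minimal sketch of that fallback (it mirrors the defaulting logic added to `build_fsdp` in the diff below; `resolve_state_dict_type` is a hypothetical helper, not a function in the codebase):

```python
from typing import Optional

from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType


def resolve_state_dict_type(
    state_dict_type: Optional[StateDictType], use_local_state_dict: bool
) -> StateDictType:
    # The new option wins when set; otherwise the legacy boolean decides
    # between the old local and full behaviors.
    if state_dict_type is not None:
        return state_dict_type
    return (
        StateDictType.LOCAL_STATE_DICT
        if use_local_state_dict
        else StateDictType.FULL_STATE_DICT
    )
```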

Reviewed By: YanjunChen329

Differential Revision: D43840887

fbshipit-source-id: d112f7b7ad97ba82fd5bf1da986b95ad7fc61c42
parent d912e9f8
@@ -12,6 +12,7 @@ from mobile_cv.torch.utils_pytorch.distributed_helper import interleave_by_rank
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     FullyShardedDataParallel as FSDP,
+    StateDictType,
 )
@@ -47,10 +48,11 @@ class FSDPCheckpointer(QATCheckpointer):
         if isinstance(self.model, FSDPWrapper):
             load_path = path
             if path:
-                # loading path is a directory: sharded local state dict is used
+                # loading path is a directory: local or sharded state dict is used
+                # TODO(T148056077): Make loading sharded state dicts more elegant
                 if self.path_manager.isdir(path):
                     self.logger.info(
-                        "[FSDPCheckpointer] Loading from local checkpoint ..."
+                        "[FSDPCheckpointer] Loading from local or sharded checkpoint ..."
                     )
                     self.model.load_local_state_dict = True
                     load_path = os.path.join(path, f"rank{comm.get_rank()}.pth")
@@ -123,7 +125,6 @@ class FSDPCheckpointer(QATCheckpointer):
             if comm.is_main_process():
                 return super().save(name, **kwargs)
             return
         data = {}
         # FSDP: model.state_dict() needs to be called by all ranks before saving
         data["model"] = self.model.state_dict()
@@ -137,7 +138,7 @@ class FSDPCheckpointer(QATCheckpointer):
         data.update(kwargs)
         # If using full state dict, only the main process does checkpoint saving; Otherwise, all processes do
-        if self.model.use_local_state_dict:
+        if self.model.state_dict_type != StateDictType.FULL_STATE_DICT:
             # Main process creates directory for local saves
             new_save_dir = os.path.join(self.save_dir, name)
             if comm.is_main_process():
@@ -181,8 +182,10 @@ def gather_optimizer_state_dict(optimizer, model: FSDPWrapper):
     Get full/local optimizer state dict from an FSDP model.
     """
     # FSDP: full_optim_state_dict() needs to be called by all ranks
-    if not model.use_local_state_dict:
+    if model.state_dict_type == StateDictType.FULL_STATE_DICT:
         return FSDP.full_optim_state_dict(model, optimizer, rank0_only=model.rank0_only)
+    elif model.state_dict_type == StateDictType.SHARDED_STATE_DICT:
+        return FSDP.sharded_optim_state_dict(model, optimizer)
     return optimizer.state_dict()
@@ -191,8 +194,12 @@ def scatter_optimizer_state_dict(optimizer, optim_state_dict, model: FSDPWrapper
     Load a full/local optimizer state dict to a FSDP model.
     If using full state dict, shard and scatter the optimizer state dict before loading
     """
-    if not model.load_local_state_dict:
+    if model.state_dict_type == StateDictType.FULL_STATE_DICT:
         optim_state_dict = FSDP.shard_full_optim_state_dict(optim_state_dict, model)
+    elif model.state_dict_type == StateDictType.SHARDED_STATE_DICT:
+        optim_state_dict = FSDP.flatten_sharded_optim_state_dict(
+            optim_state_dict, model
+        )
     optimizer.load_state_dict(optim_state_dict)
@@ -201,7 +208,7 @@ def gather_ema_state_dict(ema_state, model: FSDPWrapper):
     Get full/local EMA state dict from an FSDP model.
     If using full state dict, gather local sharded EMA states from all FSDP processes and aggregate them into a full EMA state dict
     """
-    if not model.use_local_state_dict:
+    if model.state_dict_type == StateDictType.FULL_STATE_DICT:
         # Apply local ema states to the model and unshard them
         with ema_state.apply_and_restore(model):
             with FSDP.summon_full_params(
@@ -220,7 +227,7 @@ def scatter_ema_state_dict(ema_state_dict, model: FSDPWrapper):
     Load a full/local EMA state dict to a FSDP model.
     If loading full state dict, ema_state_dict needs to be properly sharded for each FSDP process to store locally
     """
-    if not model.load_local_state_dict:
+    if model.state_dict_type == StateDictType.FULL_STATE_DICT:
         # Store the current model state.
         old_local_state = EMAState.FromModel(model)
@@ -22,6 +22,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
     FullyShardedDataParallel as FSDP,
     LocalStateDictConfig,
     MixedPrecision,
+    ShardedStateDictConfig,
     ShardingStrategy,
     StateDictType,
 )
@@ -53,8 +54,11 @@ def add_fsdp_configs(_C: CN):
     _C.FSDP.AUTO_WRAP_MIN_PARAMS = int(1e4)
     # A list of layer cls names to wrap, case sensitive
     _C.FSDP.AUTO_WRAP_LAYER_CLS = []
-    # Whether to use local state dict
+    # Whether to use local state dict -- superseded by STATE_DICT_TYPE
    _C.FSDP.USE_LOCAL_STATE_DICT = False
+    # State dict type to use when calling FSDPWrapper.state_dict() (used when saving).
+    # If None, defaults to checking the value of USE_LOCAL_STATE_DICT
+    _C.FSDP.STATE_DICT_TYPE = None
     # Whether to offload state dict to cpu
     _C.FSDP.STATE_DICT_CPU_OFFLOAD = False
     # Whether to materialize state dict on rank 0
@@ -106,6 +110,8 @@ class FSDPWrapper(FSDP):
     def __init__(
         self,
         model,
+        state_dict_type: StateDictType,
+        load_state_dict_type: StateDictType,
         amp_autocast_dtype: Optional[torch.dtype] = None,
         use_local_state_dict: bool = False,
         load_local_state_dict: bool = False,
@@ -116,6 +122,8 @@ class FSDPWrapper(FSDP):
         self.precision = amp_autocast_dtype
         self.use_local_state_dict = use_local_state_dict
         self.load_local_state_dict = load_local_state_dict
+        self.state_dict_type = state_dict_type
+        self.load_state_dict_type = load_state_dict_type
         self.offload_to_cpu = state_dict_cpu_offload
         self.rank0_only = state_dict_rank0_only
         super().__init__(model, **fsdp_kwargs)
@@ -131,22 +139,24 @@ class FSDPWrapper(FSDP):
         return super().forward(*args, **kwargs)

     @contextlib.contextmanager
-    def state_dict_type_and_config(self, is_sharded: bool) -> Generator:
-        if is_sharded:
-            state_dict_type = StateDictType.LOCAL_STATE_DICT
+    def state_dict_type_and_config(self, state_dict_type: StateDictType) -> Generator:
+        if state_dict_type == StateDictType.LOCAL_STATE_DICT:
             # only offload_to_cpu=False is supported for local state dict
             state_dict_config = LocalStateDictConfig(offload_to_cpu=False)
-        else:
-            state_dict_type = StateDictType.FULL_STATE_DICT
+        elif state_dict_type == StateDictType.FULL_STATE_DICT:
             state_dict_config = FullStateDictConfig(
                 offload_to_cpu=self.offload_to_cpu, rank0_only=self.rank0_only
             )
+        else:
+            state_dict_config = ShardedStateDictConfig(
+                offload_to_cpu=self.offload_to_cpu
+            )
         with FSDP.state_dict_type(self, state_dict_type, state_dict_config):
             yield

     def state_dict(self, *args, **kwargs):
         # NOTE: model.state_dict() needs to be called by all ranks because synchronization primitives are used
-        with self.state_dict_type_and_config(self.use_local_state_dict):
+        with self.state_dict_type_and_config(self.state_dict_type):
             return super().state_dict(*args, **kwargs)

     def load_state_dict(
@@ -155,7 +165,7 @@ class FSDPWrapper(FSDP):
         *args,
         **kwargs,
     ):
-        with self.state_dict_type_and_config(self.load_local_state_dict):
+        with self.state_dict_type_and_config(self.load_state_dict_type):
             return super().load_state_dict(state_dict, *args, **kwargs)
@@ -173,6 +183,7 @@ def build_fsdp(
     amp_autocast_dtype: Optional[torch.dtype] = None,
     use_local_state_dict: bool = False,
     load_local_state_dict: bool = False,
+    state_dict_type: Optional[StateDictType] = None,
     state_dict_cpu_offload: bool = True,
     state_dict_rank0_only: bool = True,
     ignored_modules: Optional[nn.Module] = None,
@@ -230,10 +241,27 @@ def build_fsdp(
         "forward_prefetch": forward_prefetch,
         "device_id": torch.cuda.current_device() if not device_id else device_id,
     }
+    # default to using use_local_state_dict if state_dict_type is None
+    if not state_dict_type:
+        _state_dict_type = (
+            StateDictType.LOCAL_STATE_DICT
+            if use_local_state_dict
+            else StateDictType.FULL_STATE_DICT
+        )
+    else:
+        _state_dict_type = state_dict_type
+    # load_state_dict_type defaults to load_local_state_dict
+    _load_state_dict_type = (
+        StateDictType.LOCAL_STATE_DICT
+        if load_local_state_dict
+        else StateDictType.FULL_STATE_DICT
+    )
     wrapper_kwargs = {
         "amp_autocast_dtype": amp_autocast_dtype,
         "use_local_state_dict": use_local_state_dict,
         "load_local_state_dict": load_local_state_dict,
+        "state_dict_type": _state_dict_type,
+        "load_state_dict_type": _load_state_dict_type,
         "state_dict_cpu_offload": state_dict_cpu_offload,
         "state_dict_rank0_only": state_dict_rank0_only,
     }
@@ -264,6 +292,11 @@ class FSDPModelingHook(mh.ModelingHook):
         forward_prefetch = (
             self.cfg.FSDP.FORWARD_PREFETCH_OPTION == ForwardPrefetchOption.AUTO
         )
+        _state_dict_type = (
+            StateDictType[self.cfg.FSDP.STATE_DICT_TYPE]
+            if self.cfg.FSDP.STATE_DICT_TYPE
+            else None
+        )
         wrapped_model = build_fsdp(
             model,
             sharding_algorithm=self.cfg.FSDP.ALGORITHM,
@@ -280,6 +313,7 @@ class FSDPModelingHook(mh.ModelingHook):
             amp_autocast_dtype=precision_dtype,
             use_local_state_dict=self.cfg.FSDP.USE_LOCAL_STATE_DICT,
             load_local_state_dict=self.cfg.FSDP.USE_LOCAL_STATE_DICT,
+            state_dict_type=_state_dict_type,
             state_dict_cpu_offload=self.cfg.FSDP.STATE_DICT_CPU_OFFLOAD,
             state_dict_rank0_only=self.cfg.FSDP.STATE_DICT_RANK0_ONLY,
             ignored_modules=ignored_modules,