Commit eea6339f authored by Anthony Chen, committed by Facebook GitHub Bot

Support local state dict checkpointing for FSDP

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/457

## Context:

The PyTorch FSDP (Fully Sharded Data Parallel) backend supports two checkpointing modes. The first is full_state_dict mode, where each FSDP worker summons parameters from other workers to produce a global state dict that can be loaded by non-FSDP models. This is the preferred mode for checkpointing because the checkpoint structure and key names follow the default convention. It's already supported in D39228316 (https://github.com/facebookresearch/d2go/commit/02625ff83207b836df349eadc4a61eb3d4a5810c).

However, when the model is too large to fit into a single GPU's memory, this approach fails because one worker's GPU cannot hold all the summoned parameters during checkpoint saving. The remedy is the second checkpointing mode: local_state_dict. This mode saves the sharded parameters of each GPU process locally. The result can only be loaded by FSDP-wrapped models with the same distributed training settings (i.e. number of processes), but it avoids summoning parameters and greatly reduces peak GPU memory during training.
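The memory trade-off between the two modes can be shown with a toy sketch. This is purely illustrative (plain Python lists stand in for parameter shards; `full_state` and `local_state` are hypothetical names, not the FSDP API):

```python
# Toy illustration of the two checkpointing modes; parameter shards are
# modeled as plain lists. These helpers are hypothetical, not the FSDP API.

def full_state(shards):
    # full_state_dict: one rank gathers every shard, so peak memory
    # scales with the size of the whole model
    return [p for shard in shards for p in shard]

def local_state(shards, rank):
    # local_state_dict: each rank keeps only its own shard
    return shards[rank]

shards = [[1.0, 2.0], [3.0, 4.0]]   # 2 ranks, 2 params each
print(len(full_state(shards)))      # 4 -> entire model on one rank
print(len(local_state(shards, 1)))  # 2 -> only this rank's shard
```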

This diff enables local state dict checkpointing in d2go.

## API:

This diff supports both **saving** local state and **loading** a locally sharded state dict. Whether to save local state is controlled by `FSDP.USE_LOCAL_STATE_DICT`. If `FSDP.USE_LOCAL_STATE_DICT=True` and we would have saved `output/model_0000001.pth` under the old pattern, the local checkpoints are instead saved as:
```
- output
    - model_0000001
        - rank0.pth
        - rank1.pth
        - rank2.pth
        - rank3.pth
```
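The per-rank layout above can be derived mechanically from the checkpoint name and world size. A minimal sketch (`local_checkpoint_files` is a hypothetical helper, not part of this diff):

```python
import os

def local_checkpoint_files(save_dir, name, world_size):
    # Hypothetical helper: enumerate the per-rank shard files that a
    # local-state save would produce for a checkpoint named `name`.
    ckpt_dir = os.path.join(save_dir, name)
    return [os.path.join(ckpt_dir, f"rank{r}.pth") for r in range(world_size)]

print(local_checkpoint_files("output", "model_0000001", 4))
```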
Whether to load local state, on the other hand, is controlled by the path of the checkpoint to load. If the path is a file, e.g. `output/model_final.pth`, the file is loaded as a full state dict by all GPU processes, as before. If the path is a directory, e.g. `output/model_final`, the checkpointer attempts to load `output/model_final/rankX.pth` for rank X.
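The load-side dispatch described above can be sketched as follows (a minimal sketch: `resolve_load_path` is a hypothetical name, and the `is_dir` argument stands in for the checkpointer's path manager):

```python
import os

def resolve_load_path(path, rank, is_dir=os.path.isdir):
    # Directory -> sharded local state dict: each rank loads its own file.
    # File -> full state dict: every rank loads the same file.
    if is_dir(path):
        return os.path.join(path, f"rank{rank}.pth"), True
    return path, False

# A directory path selects the per-rank shard:
print(resolve_load_path("output/model_final", 3, is_dir=lambda p: True))
# A file path is loaded as a full state dict by all ranks:
print(resolve_load_path("output/model_final.pth", 3, is_dir=lambda p: False))
```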

This API design enables the full combinations of loading local/full states and saving local/full states.

## Conversion to full state dict [Temporary]

Converting a local state dict to a full state dict is needed in an end-to-end workflow. This will be implemented in another diff.

Reviewed By: wat3rBro

Differential Revision: D41861308

fbshipit-source-id: 2e01b601683d06b46f0c5517c6cff30bbcffa8f7
parent dc6fac12
```yaml
_BASE_: "faster_rcnn_fbnetv3a_C4.yaml"
D2GO_DATA:
  TEST:
    MAX_IMAGES: 50
MODEL:
  MODELING_HOOKS: ["FSDPModelingHook"]
DATASETS:
  TRAIN: ("coco_2017_val",)
  TEST: ("coco_2017_val",)
DATALOADER:
  NUM_WORKERS: 0
FSDP:
  ALGORITHM: "full"
  USE_LOCAL_STATE_DICT: True
  # AUTO_WRAP_POLICY: ""
  STATE_DICT_CPU_OFFLOAD: False
  STATE_DICT_RANK0_ONLY: True
  LOAD_CKPT_TO_GPU: True
SOLVER:
  IMS_PER_BATCH: 32
  MAX_ITER: 5
  CHECKPOINT_PERIOD: 2
  AMP:
    ENABLED: False
    PRECISION: float16
TEST:
  EVAL_PERIOD: 10000
OUTPUT_DIR: /tmp/output
```
```diff
 # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
 import os
+from typing import cast, IO
+
 import detectron2.utils.comm as comm
 import torch
@@ -7,6 +8,7 @@ from d2go.modeling.ema import EMAState
 from d2go.quantization.modeling import QATCheckpointer
 from d2go.trainer.fsdp import FSDPWrapper
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     FullyShardedDataParallel as FSDP,
 )
@@ -23,9 +25,17 @@ class FSDPCheckpointer(QATCheckpointer):
         Add support for loading sharded optimizer states in FSDP.
         .. note:: Loading optimizer states from regular checkpoints into FSDP models is currently not supported.
-            In general users should not resume regular training with FSDP.
+            In general users should not resume non-FSDP training with FSDP.
         """
         if isinstance(self.model, FSDPWrapper):
+            if path is not None:
+                # loading path is a directory: sharded local state dict is used
+                if self.path_manager.isdir(path):
+                    self.model.load_local_state_dict = True
+                    path = os.path.join(path, f"rank{comm.get_rank()}.pth")
+                # loading path is a file: full global state dict is used
+                else:
+                    self.model.load_local_state_dict = False
             checkpointables_iter = (
                 self.checkpointables.keys()
                 if checkpointables is None
@@ -40,9 +50,9 @@ class FSDPCheckpointer(QATCheckpointer):
             checkpoint = super().load(path, checkpointables=checkpointables_filtered)
             if "optimizer" in checkpointables_iter:
                 self.logger.info("Loading optimizer from {} ...".format(path))
+                optimizer = self.checkpointables["optimizer"]
                 osd = checkpoint.pop("optimizer")
-                sharded_osd = FSDP.shard_full_optim_state_dict(osd, self.model)
-                self.checkpointables["optimizer"].load_state_dict(sharded_osd)
+                scatter_optimizer_state_dict(optimizer, osd, self.model)
             if "ema_state" in checkpointables_iter:
                 self.logger.info("Loading ema_state from {} ...".format(path))
                 ema_state = checkpoint.pop("ema_state")
@@ -59,7 +69,9 @@ class FSDPCheckpointer(QATCheckpointer):
         """
         # If no sharding, only the main process enters the saving codepath;
         # otherwise, all processes need to call state_dict() to enable state broadcasting among ranks
-        if not isinstance(self.model, FSDPWrapper) and not comm.is_main_process():
+        if not isinstance(self.model, FSDPWrapper):
+            if comm.is_main_process():
+                return super().save(name, **kwargs)
             return
         data = {}
@@ -74,32 +86,62 @@ class FSDPCheckpointer(QATCheckpointer):
             data[key] = obj.state_dict()
         data.update(kwargs)
-        # Only the main process does checkpoint saving; code copied from vision/fair/fvcore/fvcore/common/checkpoint.py
-        if comm.is_main_process():
+        # If using full state dict, only the main process does checkpoint saving; otherwise, all processes do
+        if self.model.use_local_state_dict:
+            # Main process creates directory for local saves
+            new_save_dir = os.path.join(self.save_dir, name)
+            if comm.is_main_process():
+                if not self.path_manager.exists(new_save_dir):
+                    self.path_manager.mkdirs(new_save_dir)
+            comm.synchronize()
+            # Saving checkpoints
+            basename = "rank{}.pth".format(comm.get_rank())
+            save_file = os.path.join(new_save_dir, basename)
+            assert os.path.basename(save_file) == basename, basename
+            self._save_file(data, save_file)
+            # Main process tags last checkpoint if no errors in all processes
+            comm.synchronize()
+            if comm.is_main_process():
+                self.tag_last_checkpoint(name)
+        elif comm.is_main_process():
             basename = "{}.pth".format(name)
             save_file = os.path.join(self.save_dir, basename)
             assert os.path.basename(save_file) == basename, basename
-            self.logger.info("Saving checkpoint to {}".format(save_file))
-            with self.path_manager.open(save_file, "wb") as f:
-                # pyre-fixme[6]: For 2nd param expected `Union[PathLike[typing.Any],
-                #  IO[bytes], str, BinaryIO]` but got `Union[IO[bytes], IO[str]]`.
-                torch.save(data, f)
+            self._save_file(data, save_file)
             self.tag_last_checkpoint(basename)

+    def _save_file(self, data, filename):
+        self.logger.info("Saving checkpoint to {}".format(filename))
+        with self.path_manager.open(filename, "wb") as f:
+            torch.save(data, cast(IO[bytes], f))

-def gather_optimizer_state_dict(optimizer, model=None):
+def gather_optimizer_state_dict(optimizer, model: FSDPWrapper):
+    """
+    Get full/local optimizer state dict from an FSDP model.
+    """
     # FSDP: full_optim_state_dict() needs to be called by all ranks
-    if isinstance(model, FSDPWrapper):
+    if not model.use_local_state_dict:
         return FSDP.full_optim_state_dict(model, optimizer, rank0_only=model.rank0_only)
     return optimizer.state_dict()

+def scatter_optimizer_state_dict(optimizer, optim_state_dict, model: FSDPWrapper):
+    """
+    Load a full/local optimizer state dict to an FSDP model.
+    If using full state dict, shard and scatter the optimizer state dict before loading
+    """
+    if not model.load_local_state_dict:
+        optim_state_dict = FSDP.shard_full_optim_state_dict(optim_state_dict, model)
+    optimizer.load_state_dict(optim_state_dict)

-def gather_ema_state_dict(ema_state, model):
+def gather_ema_state_dict(ema_state, model: FSDPWrapper):
     """
-    Get EMA state dict.
-    For FSDP, gather local sharded EMA states from all FSDP processes and aggregate them into a FULL GLOBAL state dict
+    Get full/local EMA state dict from an FSDP model.
+    If using full state dict, gather local sharded EMA states from all FSDP processes and aggregate them into a full EMA state dict
     """
-    if isinstance(model, FSDPWrapper):
+    if not model.use_local_state_dict:
         # Apply local ema states to the model and unshard them
         with ema_state.apply_and_restore(model):
             with FSDP.summon_full_params(
@@ -110,16 +152,15 @@ def gather_ema_state_dict(ema_state, model):
             ):
                 state = EMAState.FromModel(model)
             return state.state
-    else:
-        return ema_state.state_dict()
+    return ema_state.state_dict()

-def scatter_ema_state_dict(ema_state_dict, model):
+def scatter_ema_state_dict(ema_state_dict, model: FSDPWrapper):
     """
-    Load an EMA state dict to the model.
-    EMA state represents a FULL GLOBAL state dict and needs to be properly sharded for each FSDP process to store locally
+    Load a full/local EMA state dict to an FSDP model.
+    If loading full state dict, ema_state_dict needs to be properly sharded for each FSDP process to store locally
     """
-    if isinstance(model, FSDPWrapper):
+    if not model.load_local_state_dict:
         # Store the current model state.
         old_local_state = EMAState.FromModel(model)
...
```
```diff
 #!/usr/bin/env python3
 # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+import contextlib
 import logging
 from enum import Enum
 from functools import partial
-from typing import Callable, Iterable, Optional
+from typing import Callable, Generator, Iterable, Optional

-import detectron2.utils.comm as comm
 import torch
 import torch.nn as nn
 from d2go.config import CfgNode as CN
@@ -19,6 +19,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
     CPUOffload,
     FullStateDictConfig,
     FullyShardedDataParallel as FSDP,
+    LocalStateDictConfig,
     MixedPrecision,
     ShardingStrategy,
     StateDictType,
@@ -51,8 +52,10 @@ def add_fsdp_configs(_C: CN):
     _C.FSDP.AUTO_WRAP_MIN_PARAMS = int(1e4)
     # A list of layer cls names to wrap, case sensitive
     _C.FSDP.AUTO_WRAP_LAYER_CLS = []
+    # Whether to use local state dict
+    _C.FSDP.USE_LOCAL_STATE_DICT = False
     # Whether to offload state dict to cpu
-    _C.FSDP.STATE_DICT_CPU_OFFLOAD = True
+    _C.FSDP.STATE_DICT_CPU_OFFLOAD = False
     # Whether to materialize state dict on rank 0
     _C.FSDP.STATE_DICT_RANK0_ONLY = True
@@ -81,22 +84,35 @@ class FSDPWrapper(FSDP):
     def __init__(
         self,
         model,
+        use_local_state_dict=False,
+        load_local_state_dict=False,
         state_dict_cpu_offload=True,
         state_dict_rank0_only=True,
         **fsdp_kwargs,
     ):
+        self.use_local_state_dict = use_local_state_dict
+        self.load_local_state_dict = load_local_state_dict
         self.offload_to_cpu = state_dict_cpu_offload
         self.rank0_only = state_dict_rank0_only
         super().__init__(model, **fsdp_kwargs)

-    def state_dict(self, *args, **kwargs):
-        # TODO: support local state dict
-        # NOTE: model.state_dict() needs to be called by all ranks because synchronization primitives are used
-        save_policy = FullStateDictConfig(
-            offload_to_cpu=self.offload_to_cpu, rank0_only=self.rank0_only
-        )
-        with FSDP.state_dict_type(self, StateDictType.FULL_STATE_DICT, save_policy):
-            return super().state_dict(*args, **kwargs)
+    @contextlib.contextmanager
+    def state_dict_type_and_config(self, is_sharded: bool) -> Generator:
+        if is_sharded:
+            state_dict_type = StateDictType.LOCAL_STATE_DICT
+            # only offload_to_cpu=False is supported for local state dict
+            state_dict_config = LocalStateDictConfig(offload_to_cpu=False)
+        else:
+            state_dict_type = StateDictType.FULL_STATE_DICT
+            state_dict_config = FullStateDictConfig(
+                offload_to_cpu=self.offload_to_cpu, rank0_only=self.rank0_only
+            )
+        with FSDP.state_dict_type(self, state_dict_type, state_dict_config):
+            yield
+
+    def state_dict(self, *args, **kwargs):
+        # NOTE: model.state_dict() needs to be called by all ranks because synchronization primitives are used
+        with self.state_dict_type_and_config(self.use_local_state_dict):
+            return super().state_dict(*args, **kwargs)

     def load_state_dict(
@@ -105,7 +121,7 @@ class FSDPWrapper(FSDP):
         *args,
         **kwargs,
     ):
-        with FSDP.state_dict_type(self, StateDictType.FULL_STATE_DICT):
+        with self.state_dict_type_and_config(self.load_local_state_dict):
             return super().load_state_dict(state_dict, *args, **kwargs)
@@ -120,6 +136,8 @@ def build_fsdp(
     param_dtype: Optional[torch.dtype] = None,
     reduce_dtype: Optional[torch.dtype] = None,
     buffer_dtype: Optional[torch.dtype] = None,
+    use_local_state_dict: bool = False,
+    load_local_state_dict: bool = False,
     state_dict_cpu_offload: bool = True,
     state_dict_rank0_only: bool = True,
     device_id: Optional[int] = None,
@@ -163,6 +181,8 @@ def build_fsdp(
         "device_id": torch.cuda.current_device() if not device_id else device_id,
     }
     wrapper_kwargs = {
+        "use_local_state_dict": use_local_state_dict,
+        "load_local_state_dict": load_local_state_dict,
         "state_dict_cpu_offload": state_dict_cpu_offload,
         "state_dict_rank0_only": state_dict_rank0_only,
     }
@@ -194,6 +214,8 @@ class FSDPModelingHook(mh.ModelingHook):
             param_dtype=precision_dtype,
             reduce_dtype=precision_dtype,
             buffer_dtype=precision_dtype,
+            use_local_state_dict=self.cfg.FSDP.USE_LOCAL_STATE_DICT,
+            load_local_state_dict=self.cfg.FSDP.USE_LOCAL_STATE_DICT,
             state_dict_cpu_offload=self.cfg.FSDP.STATE_DICT_CPU_OFFLOAD,
             state_dict_rank0_only=self.cfg.FSDP.STATE_DICT_RANK0_ONLY,
             device_id=torch.cuda.current_device(),
...
```