Commit eea6339f authored by Anthony Chen, committed by Facebook GitHub Bot

Support local state dict checkpointing for FSDP

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/457

## Context:

The PyTorch FSDP (Fully Sharded Data Parallel) backend supports two checkpointing modes. The first is full_state_dict mode, where each FSDP worker summons parameters from the other workers to produce a global state dict that can be loaded by non-FSDP models. This is the preferred mode for checkpointing because the checkpoint structure and key names follow the default convention. It's already supported in D39228316 (https://github.com/facebookresearch/d2go/commit/02625ff83207b836df349eadc4a61eb3d4a5810c).

However, when the model is too large to fit into a single GPU's memory, this approach fails because a worker's GPU can't hold all the summoned parameters during checkpoint saving. The fallback is the second checkpointing mode: local_state_dict. This mode saves each GPU process's sharded parameters locally. The resulting checkpoint can only be loaded by FSDP-wrapped models with the same distributed training settings (i.e. number of processes), but it avoids summoning parameters and greatly reduces peak GPU memory during training.

This diff enables local state dict checkpointing in d2go.
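
For reference, the two modes correspond to PyTorch's `StateDictType.FULL_STATE_DICT` and `StateDictType.LOCAL_STATE_DICT`. A minimal sketch of how the choice is applied, assuming a PyTorch version that provides these state dict configs (the helper name `get_checkpoint_state` is illustrative; in this diff the equivalent logic lives in `FSDPWrapper.state_dict_type_and_config`):

```
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    LocalStateDictConfig,
    StateDictType,
)


def get_checkpoint_state(fsdp_model: FSDP, use_local_state: bool):
    """Return a local (sharded) or full (unsharded) state dict from an FSDP model."""
    if use_local_state:
        # Local mode: each rank keeps only its own shards; nothing is summoned,
        # so peak GPU memory stays low.
        sd_type = StateDictType.LOCAL_STATE_DICT
        sd_config = LocalStateDictConfig(offload_to_cpu=False)
    else:
        # Full mode: parameters are gathered across ranks into a global state
        # dict that non-FSDP models can load.
        sd_type = StateDictType.FULL_STATE_DICT
        sd_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_model, sd_type, sd_config):
        return fsdp_model.state_dict()
```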

## API:

This diff supports both **saving** a local state dict and **loading** a state dict that is locally sharded. Whether to save local state is controlled by `FSDP.USE_LOCAL_STATE_DICT`. If `FSDP.USE_LOCAL_STATE_DICT=True` and we would have saved `output/model_0000001.pth` under the old pattern, the local checkpoints are saved as:
```
- output
    - model_0000001
        - rank0.pth
        - rank1.pth
        - rank2.pth
        - rank3.pth
```
Whether to load local state, on the other hand, is controlled by the path of the checkpoint to load. If the path is a file, e.g. `output/model_final.pth`, the file is loaded as a full state dict by all GPU processes, as before. If the path is a directory, e.g. `output/model_final`, the checkpointer attempts to load `output/model_final/rankX.pth` on rank X.
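
A minimal sketch of the resulting dispatch, mirroring the `FSDPCheckpointer.load` change in this diff (the helper name `resolve_checkpoint_path` is illustrative, not part of the API):

```
import os

import detectron2.utils.comm as comm


def resolve_checkpoint_path(path: str, path_manager):
    """Return (path_to_load, is_local_state) for the current rank."""
    if path_manager.isdir(path):
        # Directory: each rank loads its own shard, e.g. output/model_final/rank3.pth
        return os.path.join(path, f"rank{comm.get_rank()}.pth"), True
    # File: all ranks load the same full state dict, e.g. output/model_final.pth
    return path, False
```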

This API design enables all combinations of saving local/full state and loading local/full state.

## Conversion to full state dict [Temporary]

Converting a local state dict back to a full state dict is needed in an e2e workflow. This will be implemented in another diff.

Reviewed By: wat3rBro

Differential Revision: D41861308

fbshipit-source-id: 2e01b601683d06b46f0c5517c6cff30bbcffa8f7
parent dc6fac12
_BASE_: "faster_rcnn_fbnetv3a_C4.yaml"
D2GO_DATA:
  TEST:
    MAX_IMAGES: 50
MODEL:
  MODELING_HOOKS: ["FSDPModelingHook"]
DATASETS:
  TRAIN: ("coco_2017_val",)
  TEST: ("coco_2017_val",)
DATALOADER:
  NUM_WORKERS: 0
FSDP:
  ALGORITHM: "full"
  USE_LOCAL_STATE_DICT: True
  # AUTO_WRAP_POLICY: ""
  STATE_DICT_CPU_OFFLOAD: False
  STATE_DICT_RANK0_ONLY: True
LOAD_CKPT_TO_GPU: True
SOLVER:
  IMS_PER_BATCH: 32
  MAX_ITER: 5
  CHECKPOINT_PERIOD: 2
  AMP:
    ENABLED: False
    PRECISION: float16
TEST:
  EVAL_PERIOD: 10000
OUTPUT_DIR: /tmp/output
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import os
from typing import cast, IO

import detectron2.utils.comm as comm
import torch
@@ -7,6 +8,7 @@ from d2go.modeling.ema import EMAState
from d2go.quantization.modeling import QATCheckpointer
from d2go.trainer.fsdp import FSDPWrapper
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullyShardedDataParallel as FSDP,
)
@@ -23,9 +25,17 @@ class FSDPCheckpointer(QATCheckpointer):
        Add support for loading sharded optimizer states in FSDP.
        .. note:: Loading optimizer states from regular checkpoints into FSDP models is currently not supported.
            In general users should not resume regular training with FSDP.
            In general users should not resume non-FSDP training with FSDP.
        """
        if isinstance(self.model, FSDPWrapper):
            if path is not None:
                # loading path is a directory: sharded local state dict is used
                if self.path_manager.isdir(path):
                    self.model.load_local_state_dict = True
                    path = os.path.join(path, f"rank{comm.get_rank()}.pth")
                # loading path is a file: full global state dict is used
                else:
                    self.model.load_local_state_dict = False
            checkpointables_iter = (
                self.checkpointables.keys()
                if checkpointables is None
@@ -40,9 +50,9 @@ class FSDPCheckpointer(QATCheckpointer):
            checkpoint = super().load(path, checkpointables=checkpointables_filtered)
            if "optimizer" in checkpointables_iter:
                self.logger.info("Loading optimizer from {} ...".format(path))
                optimizer = self.checkpointables["optimizer"]
                osd = checkpoint.pop("optimizer")
                sharded_osd = FSDP.shard_full_optim_state_dict(osd, self.model)
                self.checkpointables["optimizer"].load_state_dict(sharded_osd)
                scatter_optimizer_state_dict(optimizer, osd, self.model)
            if "ema_state" in checkpointables_iter:
                self.logger.info("Loading ema_state from {} ...".format(path))
                ema_state = checkpoint.pop("ema_state")
@@ -59,7 +69,9 @@ class FSDPCheckpointer(QATCheckpointer):
        """
        # If no sharding, only the main process enters the saving codepath;
        # otherwise, all processes need to call state_dict() to enable state broadcasting among ranks
        if not isinstance(self.model, FSDPWrapper) and not comm.is_main_process():
        if not isinstance(self.model, FSDPWrapper):
            if comm.is_main_process():
                return super().save(name, **kwargs)
            return
        data = {}
@@ -74,32 +86,62 @@ class FSDPCheckpointer(QATCheckpointer):
            data[key] = obj.state_dict()
        data.update(kwargs)
        # Only the main process does checkpoint saving; code copied from vision/fair/fvcore/fvcore/common/checkpoint.py
        if comm.is_main_process():
        # If using full state dict, only the main process does checkpoint saving; Otherwise, all processes do
        if self.model.use_local_state_dict:
            # Main process creates directory for local saves
            new_save_dir = os.path.join(self.save_dir, name)
            if comm.is_main_process():
                if not self.path_manager.exists(new_save_dir):
                    self.path_manager.mkdirs(new_save_dir)
            comm.synchronize()
            # Saving checkpoints
            basename = "rank{}.pth".format(comm.get_rank())
            save_file = os.path.join(new_save_dir, basename)
            assert os.path.basename(save_file) == basename, basename
            self._save_file(data, save_file)
            # Main process tags last checkpoint if no errors in all processes
            comm.synchronize()
            if comm.is_main_process():
                self.tag_last_checkpoint(name)
        elif comm.is_main_process():
            basename = "{}.pth".format(name)
            save_file = os.path.join(self.save_dir, basename)
            assert os.path.basename(save_file) == basename, basename
            self.logger.info("Saving checkpoint to {}".format(save_file))
            with self.path_manager.open(save_file, "wb") as f:
                # pyre-fixme[6]: For 2nd param expected `Union[PathLike[typing.Any],
                # IO[bytes], str, BinaryIO]` but got `Union[IO[bytes], IO[str]]`.
                torch.save(data, f)
            self._save_file(data, save_file)
            self.tag_last_checkpoint(basename)

    def _save_file(self, data, filename):
        self.logger.info("Saving checkpoint to {}".format(filename))
        with self.path_manager.open(filename, "wb") as f:
            torch.save(data, cast(IO[bytes], f))
def gather_optimizer_state_dict(optimizer, model=None):
def gather_optimizer_state_dict(optimizer, model: FSDPWrapper):
    """
    Get full/local optimizer state dict from an FSDP model.
    """
    # FSDP: full_optim_state_dict() needs to be called by all ranks
    if isinstance(model, FSDPWrapper):
    if not model.use_local_state_dict:
        return FSDP.full_optim_state_dict(model, optimizer, rank0_only=model.rank0_only)
    return optimizer.state_dict()
def gather_ema_state_dict(ema_state, model):
def scatter_optimizer_state_dict(optimizer, optim_state_dict, model: FSDPWrapper):
    """
    Get EMA state dict.
    For FSDP, gather local sharded EMA states from all FSDP processes and aggregate them into a FULL GLOBAL state dict
    Load a full/local optimizer state dict to a FSDP model.
    If using full state dict, shard and scatter the optimizer state dict before loading
    """
    if isinstance(model, FSDPWrapper):
    if not model.load_local_state_dict:
        optim_state_dict = FSDP.shard_full_optim_state_dict(optim_state_dict, model)
    optimizer.load_state_dict(optim_state_dict)
def gather_ema_state_dict(ema_state, model: FSDPWrapper):
    """
    Get full/local EMA state dict from an FSDP model.
    If using full state dict, gather local sharded EMA states from all FSDP processes and aggregate them into a full EMA state dict
    """
    if not model.use_local_state_dict:
        # Apply local ema states to the model and unshard them
        with ema_state.apply_and_restore(model):
            with FSDP.summon_full_params(
@@ -110,16 +152,15 @@ def gather_ema_state_dict(ema_state, model):
            ):
                state = EMAState.FromModel(model)
            return state.state
    else:
        return ema_state.state_dict()
    return ema_state.state_dict()
def scatter_ema_state_dict(ema_state_dict, model):
def scatter_ema_state_dict(ema_state_dict, model: FSDPWrapper):
    """
    Load an EMA state dict to the model.
    EMA state represents a FULL GLOBAL state dict and needs to be properly sharded for each FSDP process to store locally
    Load a full/local EMA state dict to a FSDP model.
    If loading full state dict, ema_state_dict needs to be properly sharded for each FSDP process to store locally
    """
    if isinstance(model, FSDPWrapper):
    if not model.load_local_state_dict:
        # Store the current model state.
        old_local_state = EMAState.FromModel(model)
......
#!/usr/bin/env python3
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import contextlib
import logging
from enum import Enum
from functools import partial
from typing import Callable, Iterable, Optional
from typing import Callable, Generator, Iterable, Optional
import detectron2.utils.comm as comm
import torch
import torch.nn as nn
from d2go.config import CfgNode as CN
@@ -19,6 +19,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    LocalStateDictConfig,
    MixedPrecision,
    ShardingStrategy,
    StateDictType,
@@ -51,8 +52,10 @@ def add_fsdp_configs(_C: CN):
    _C.FSDP.AUTO_WRAP_MIN_PARAMS = int(1e4)
    # A list of layer cls names to wrap, case sensitive
    _C.FSDP.AUTO_WRAP_LAYER_CLS = []
    # Whether to use local state dict
    _C.FSDP.USE_LOCAL_STATE_DICT = False
    # Whether to offload state dict to cpu
    _C.FSDP.STATE_DICT_CPU_OFFLOAD = True
    _C.FSDP.STATE_DICT_CPU_OFFLOAD = False
    # Whether to materialize state dict on rank 0
    _C.FSDP.STATE_DICT_RANK0_ONLY = True
@@ -81,22 +84,35 @@ class FSDPWrapper(FSDP):
    def __init__(
        self,
        model,
        use_local_state_dict=False,
        load_local_state_dict=False,
        state_dict_cpu_offload=True,
        state_dict_rank0_only=True,
        **fsdp_kwargs,
    ):
        self.use_local_state_dict = use_local_state_dict
        self.load_local_state_dict = load_local_state_dict
        self.offload_to_cpu = state_dict_cpu_offload
        self.rank0_only = state_dict_rank0_only
        super().__init__(model, **fsdp_kwargs)

    @contextlib.contextmanager
    def state_dict_type_and_config(self, is_sharded: bool) -> Generator:
        if is_sharded:
            state_dict_type = StateDictType.LOCAL_STATE_DICT
            # only offload_to_cpu=False is supported for local state dict
            state_dict_config = LocalStateDictConfig(offload_to_cpu=False)
        else:
            state_dict_type = StateDictType.FULL_STATE_DICT
            state_dict_config = FullStateDictConfig(
                offload_to_cpu=self.offload_to_cpu, rank0_only=self.rank0_only
            )
        with FSDP.state_dict_type(self, state_dict_type, state_dict_config):
            yield

    def state_dict(self, *args, **kwargs):
        # TODO: support local state dict
        # NOTE: model.state_dict() needs to be called by all ranks because synchronization primitives are used
        save_policy = FullStateDictConfig(
            offload_to_cpu=self.offload_to_cpu, rank0_only=self.rank0_only
        )
        with FSDP.state_dict_type(self, StateDictType.FULL_STATE_DICT, save_policy):
        with self.state_dict_type_and_config(self.use_local_state_dict):
            return super().state_dict(*args, **kwargs)

    def load_state_dict(
@@ -105,7 +121,7 @@ class FSDPWrapper(FSDP):
        *args,
        **kwargs,
    ):
        with FSDP.state_dict_type(self, StateDictType.FULL_STATE_DICT):
        with self.state_dict_type_and_config(self.load_local_state_dict):
            return super().load_state_dict(state_dict, *args, **kwargs)
@@ -120,6 +136,8 @@ def build_fsdp(
    param_dtype: Optional[torch.dtype] = None,
    reduce_dtype: Optional[torch.dtype] = None,
    buffer_dtype: Optional[torch.dtype] = None,
    use_local_state_dict: bool = False,
    load_local_state_dict: bool = False,
    state_dict_cpu_offload: bool = True,
    state_dict_rank0_only: bool = True,
    device_id: Optional[int] = None,
@@ -163,6 +181,8 @@ def build_fsdp(
        "device_id": torch.cuda.current_device() if not device_id else device_id,
    }
    wrapper_kwargs = {
        "use_local_state_dict": use_local_state_dict,
        "load_local_state_dict": load_local_state_dict,
        "state_dict_cpu_offload": state_dict_cpu_offload,
        "state_dict_rank0_only": state_dict_rank0_only,
    }
@@ -194,6 +214,8 @@ class FSDPModelingHook(mh.ModelingHook):
            param_dtype=precision_dtype,
            reduce_dtype=precision_dtype,
            buffer_dtype=precision_dtype,
            use_local_state_dict=self.cfg.FSDP.USE_LOCAL_STATE_DICT,
            load_local_state_dict=self.cfg.FSDP.USE_LOCAL_STATE_DICT,
            state_dict_cpu_offload=self.cfg.FSDP.STATE_DICT_CPU_OFFLOAD,
            state_dict_rank0_only=self.cfg.FSDP.STATE_DICT_RANK0_ONLY,
            device_id=torch.cuda.current_device(),
......