"tests/python/common/test_sparse_ops-csr.py" did not exist on "ab2bd1f13dd982be501b5874e9d4eb2217b068bc"
Commit dc6fac12 authored by Anthony Chen, committed by Facebook GitHub Bot

Rewrite FSDP wrapping as modeling hook

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/440

Move FSDP wrapping to runner.build_model by rewriting it as a modeling hook

**Motivation**
When a model is too large to run inference on a single GPU, it requires FSDP with local checkpointing mode to reduce peak GPU memory. However, in the eval_pytorch workflow (train_net with eval-only), models are evaluated without being wrapped in FSDP, which can cause the OOM errors described above. Thus, it is better practice to wrap the model with FSDP during `runner.build_model(cfg)`, so that evaluation runs under the same FSDP setting as training.

This diff moves FSDP wrapping to `runner.build_model(cfg)` by rewriting it as a modeling hook.
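
For context, a D2Go modeling hook is an object with `apply()`/`unapply()` methods that `runner.build_model(cfg)` runs over the freshly built model. The sketch below only illustrates the pattern; the real interface lives in `d2go.modeling.modeling_hook`, and the names here are simplified stand-ins rather than the actual implementation.

```python
# Illustrative sketch of the modeling-hook pattern this diff adopts.
# The real base class is d2go.modeling.modeling_hook.ModelingHook; this
# simplified version only shows the shape of the interface.
from typing import List

import torch.nn as nn


class ModelingHook:
    """Transforms a model while it is being built."""

    def apply(self, model: nn.Module) -> nn.Module:
        raise NotImplementedError

    def unapply(self, model: nn.Module) -> nn.Module:
        raise NotImplementedError


def build_model_with_hooks(model: nn.Module, hooks: List[ModelingHook]) -> nn.Module:
    # build_model applies each configured hook in order; with an FSDP hook in
    # the list, the returned model is already wrapped for sharded execution,
    # so eval-only runs see the same wrapping as training.
    for hook in hooks:
        model = hook.apply(model)
    return model
```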

**API changes**
* Users need to append `"FSDPModelingHook"` to `MODEL.MODELING_HOOKS` to enable FSDP (see the usage sketch after this list).
* `FSDP.ALGORITHM` can only be `full` or `grad_optim`.
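
A hypothetical usage sketch, assuming the usual d2go entry points (`create_runner`, `runner.get_default_cfg()`) and using only config keys that appear in this diff; the runner class and values are placeholders:

```python
# Hypothetical example of enabling FSDP via the new modeling hook.
from d2go.runner import create_runner

runner = create_runner("d2go.runner.GeneralizedRCNNRunner")  # placeholder runner class
cfg = runner.get_default_cfg()
cfg.MODEL.MODELING_HOOKS = ["FSDPModelingHook"]  # append to any existing hooks in real configs
cfg.FSDP.ALGORITHM = "grad_optim"  # or "full"
model = runner.build_model(cfg)  # model comes back FSDP-wrapped, for both train and eval
```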

**Note**
It's not possible to unwrap an FSDP model back into the original model, so `FSDPModelingHook.unapply()` can't be implemented.

Reviewed By: wat3rBro

Differential Revision: D41416917

fbshipit-source-id: f3fc72d574cc6ccbe0d238e48c575926ba5b4d06
parent 7b2ba6cb
@@ -59,7 +59,7 @@ class FSDPCheckpointer(QATCheckpointer):
"""
# If no sharding, only the main process enters the saving codepath;
# otherwise, all processes need to call state_dict() to enable state broadcasting among ranks
if not isinstance(self.model, FSDP) and not comm.is_main_process():
if not isinstance(self.model, FSDPWrapper) and not comm.is_main_process():
return
data = {}
......
@@ -517,7 +517,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
_get_model_with_abnormal_checker(model),
data_loader,
optimizer,
grad_scaler=get_grad_scaler(cfg.FSDP.ALGORITHM),
grad_scaler=get_grad_scaler(cfg),
precision=parse_precision_from_string(
cfg.SOLVER.AMP.PRECISION, lightning=False
),
......
@@ -7,9 +7,11 @@ from typing import Callable, Iterable, Optional
import detectron2.utils.comm as comm
import torch
import torch.nn as nn
from d2go.config import CfgNode as CN
from d2go.modeling import modeling_hook as mh
from d2go.registry.builtin import MODELING_HOOK_REGISTRY
from d2go.trainer.helper import parse_precision_from_string
from detectron2.engine.defaults import create_ddp_model
from detectron2.utils.registry import Registry
from torch.cuda.amp import GradScaler
from torch.distributed.fsdp.fully_sharded_data_parallel import (
@@ -36,7 +38,7 @@ D2GO_FSDP_WRAP_POLICY_REGISTRY = Registry("D2GO_FSDP_WRAP_POLICY_REGISTRY")
def add_fsdp_configs(_C: CN):
_C.FSDP = CN()
_C.FSDP.ALGORITHM = "" # 'grad_optim', 'full' or ''
_C.FSDP.ALGORITHM = "" # 'grad_optim' or 'full'
# Configs for fully sharded data parallel (fsdp)
# Check out https://pytorch.org/docs/stable/fsdp.html
@@ -56,36 +58,23 @@ def add_fsdp_configs(_C: CN):
class ShardingAlgorithm(str, Enum):
"""
This enum specifies the sharding algorithm to be used by FullyShardedDataParallel (FSDP).
It matches the strings used in D2Go config with the enum class :class:`ShardingStrategy` used by Pytorch FSDP module:
"grad_optim" => ShardingAlgorithm.SHARD_GRAD_OP => ShardingStrategy.SHARD_GRAD_OP
"full" => ShardingAlgorithm.FULL_SHARD => ShardingStrategy.FULL_SHARD
"""
SHARD_GRAD_OP = "grad_optim"
FULL_SHARD = "full"
NO_SHARD = ""
@classmethod
def is_valid(cls, key):
return key in {item.value for item in ShardingAlgorithm}
@classmethod
def use_sharding(cls, key):
return key in [cls.SHARD_GRAD_OP, cls.FULL_SHARD]
def is_fsdp_enabled(cfg):
return "FSDPModelingHook" in cfg.MODEL.MODELING_HOOKS
def get_module_class_from_name(module, name):
"""
Gets a class from a module by its name. Code borrowed from HuggingFace
Args:
module (`torch.nn.Module`): The module to get the class from.
name (`str`): The name of the class.
"""
modules_children = list(module.children())
if module.__class__.__name__ == name:
return module.__class__
elif len(modules_children) == 0:
return
else:
for child_module in modules_children:
module_class = get_module_class_from_name(child_module, name)
if module_class is not None:
return module_class
def get_grad_scaler(cfg):
return ShardedGradScaler() if is_fsdp_enabled(cfg) else GradScaler()
class FSDPWrapper(FSDP):
@@ -143,7 +132,7 @@ def build_fsdp(
logger.info("Optimizer + Gradient + Horizontal Model Sharding (ZeRO-3) is used")
else:
raise ValueError(
f"Invalid sharding algorithm for building FSDP. Can be either {ShardingAlgorithm.SHARD_GRAD_OP} or {ShardingAlgorithm.FULL_SHARD}."
f"Invalid sharding algorithm for FSDP. Can be either {ShardingAlgorithm.SHARD_GRAD_OP} or {ShardingAlgorithm.FULL_SHARD}."
)
auto_wrap_policy = (
@@ -181,52 +170,59 @@ def build_fsdp(
return FSDPWrapper(model, **wrapper_kwargs, **fsdp_kwargs)
def create_ddp_model_with_sharding(cfg, model):
if not ShardingAlgorithm.is_valid(cfg.FSDP.ALGORITHM):
raise ValueError(
f"Invalid FSDP sharding algorithm. Can only be one of {[item.value for item in ShardingAlgorithm]}"
)
elif ShardingAlgorithm.use_sharding(cfg.FSDP.ALGORITHM):
@MODELING_HOOK_REGISTRY.register()
class FSDPModelingHook(mh.ModelingHook):
"""Modeling hook that wraps model in FSDP based on config"""
def apply(self, model: nn.Module) -> FSDPWrapper:
# SOLVER.AMP.ENABLED and SOLVER.AMP.PRECISION controls mixed precision for all parameters, buffers and reduce in FSDP
precision_dtype = (
parse_precision_from_string(cfg.SOLVER.AMP.PRECISION, lightning=False)
if cfg.SOLVER.AMP.ENABLED
parse_precision_from_string(self.cfg.SOLVER.AMP.PRECISION, lightning=False)
if self.cfg.SOLVER.AMP.ENABLED
else None
)
wrapped_model = build_fsdp(
model,
sharding_algorithm=cfg.FSDP.ALGORITHM,
auto_wrap_policy_name=cfg.FSDP.AUTO_WRAP_POLICY,
sharding_algorithm=self.cfg.FSDP.ALGORITHM,
auto_wrap_policy_name=self.cfg.FSDP.AUTO_WRAP_POLICY,
auto_wrap_policy_kwargs={
"min_num_params": cfg.FSDP.AUTO_WRAP_MIN_PARAMS,
"layer_names": cfg.FSDP.AUTO_WRAP_LAYER_CLS,
"min_num_params": self.cfg.FSDP.AUTO_WRAP_MIN_PARAMS,
"layer_names": self.cfg.FSDP.AUTO_WRAP_LAYER_CLS,
},
use_cpu_offload=cfg.FSDP.CPU_OFFLOAD,
use_backward_prefetch=cfg.FSDP.BACKWARD_PREFETCH,
use_cpu_offload=self.cfg.FSDP.CPU_OFFLOAD,
use_backward_prefetch=self.cfg.FSDP.BACKWARD_PREFETCH,
param_dtype=precision_dtype,
reduce_dtype=precision_dtype,
buffer_dtype=precision_dtype,
state_dict_cpu_offload=cfg.FSDP.STATE_DICT_CPU_OFFLOAD,
state_dict_rank0_only=cfg.FSDP.STATE_DICT_RANK0_ONLY,
state_dict_cpu_offload=self.cfg.FSDP.STATE_DICT_CPU_OFFLOAD,
state_dict_rank0_only=self.cfg.FSDP.STATE_DICT_RANK0_ONLY,
device_id=torch.cuda.current_device(),
)
else:
wrapped_model = create_ddp_model(
model,
fp16_compression=cfg.MODEL.DDP_FP16_GRAD_COMPRESS,
device_ids=None if cfg.MODEL.DEVICE == "cpu" else [comm.get_local_rank()],
broadcast_buffers=False,
find_unused_parameters=cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS,
return wrapped_model
def unapply(self, model: FSDPWrapper) -> nn.Module:
raise NotImplementedError(
"FSDPModelingHook.unapply() not implemented: can't unwrap a FSDP module"
)
return wrapped_model
def get_grad_scaler(fsdp_algorithm):
return (
ShardedGradScaler()
if ShardingAlgorithm.use_sharding(fsdp_algorithm)
else GradScaler()
)
def get_module_class_from_name(module, name):
"""
Gets a class from a module by its name. Code borrowed from HuggingFace
Args:
module (`torch.nn.Module`): The module to get the class from.
name (`str`): The name of the class.
"""
modules_children = list(module.children())
if module.__class__.__name__ == name:
return module.__class__
elif len(modules_children) == 0:
return
else:
for child_module in modules_children:
module_class = get_module_class_from_name(child_module, name)
if module_class is not None:
return module_class
@D2GO_FSDP_WRAP_POLICY_REGISTRY.register()
......
@@ -9,6 +9,7 @@ import logging
import sys
from typing import List, Type, Union
import detectron2.utils.comm as comm
from d2go.config import CfgNode
from d2go.distributed import launch
from d2go.runner import BaseRunner
@@ -22,12 +23,13 @@ from d2go.setup import (
setup_root_logger,
)
from d2go.trainer.api import TestNetOutput, TrainNetOutput
from d2go.trainer.fsdp import create_ddp_model_with_sharding
from d2go.trainer.fsdp import is_fsdp_enabled
from d2go.utils.misc import (
dump_trained_model_configs,
print_metrics_table,
save_binary_outputs,
)
from detectron2.engine.defaults import create_ddp_model
logger = logging.getLogger("d2go.tools.train_net")
@@ -62,14 +64,23 @@ def main(
metrics=metrics,
)
wrapped_model = create_ddp_model_with_sharding(cfg, model)
# Use DDP if FSDP is not enabled
# TODO (T142223289): rewrite ddp wrapping as modeling hook
if not is_fsdp_enabled(cfg):
model = create_ddp_model(
model,
fp16_compression=cfg.MODEL.DDP_FP16_GRAD_COMPRESS,
device_ids=None if cfg.MODEL.DEVICE == "cpu" else [comm.get_local_rank()],
broadcast_buffers=False,
find_unused_parameters=cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS,
)
trained_cfgs = runner.do_train(cfg, wrapped_model, resume=resume)
trained_cfgs = runner.do_train(cfg, model, resume=resume)
final_eval = cfg.TEST.FINAL_EVAL
if final_eval:
# run evaluation after training in the same processes
metrics = runner.do_test(cfg, wrapped_model)
metrics = runner.do_test(cfg, model)
print_metrics_table(metrics)
else:
metrics = {}
......