Commit dc6fac12 authored by Anthony Chen, committed by Facebook GitHub Bot

Rewrite FSDP wrapping as modeling hook

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/440

Move FSDP wrapping to runner.build_model by rewriting it as a modeling hook

**Motivation**
When a model is too large to run inference on a single GPU, it must use FSDP with local checkpointing mode to reduce peak GPU memory. However, in the eval_pytorch workflow (train_net with eval-only), models are evaluated without being wrapped in FSDP, which can cause OOM errors for the reason above. It is therefore better practice to wrap the model with FSDP during `runner.build_model(cfg)`, so evaluation runs with the same FSDP setting as training.

This diff moves FSDP wrapping to `runner.build_model(cfg)` by rewriting it as a modeling hook.
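
For context, the modeling-hook pattern this relies on looks roughly like the sketch below. This is an illustrative simplification of `d2go.modeling.modeling_hook` and `MODELING_HOOK_REGISTRY`; the helper `build_model_with_hooks` is hypothetical and only indicates where hooks run inside `runner.build_model(cfg)`:

```python
# Simplified sketch of the d2go modeling-hook pattern (illustrative only).
import torch.nn as nn


class ModelingHook:
    """Base interface: a hook transforms the model right after it is built."""

    def __init__(self, cfg):
        self.cfg = cfg

    def apply(self, model: nn.Module) -> nn.Module:
        raise NotImplementedError

    def unapply(self, model: nn.Module) -> nn.Module:
        raise NotImplementedError


def build_model_with_hooks(cfg, model: nn.Module, hook_registry) -> nn.Module:
    # Hypothetical helper: runner.build_model builds the raw model, then applies
    # every hook named in MODEL.MODELING_HOOKS, in order.
    for name in cfg.MODEL.MODELING_HOOKS:
        hook = hook_registry.get(name)(cfg)
        model = hook.apply(model)
    return model
```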

**API changes**
* Users need to append `"FSDPModelingHook"` to `MODEL.MODELING_HOOKS` to enable FSDP (see the config sketch below).
* `FSDP.ALGORITHM` can only be `full` or `grad_optim`.
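
A config sketch, for illustration only: the keys are the ones touched in this diff, while the values and the runner class are examples, and `get_default_cfg()` is assumed to already include the FSDP keys.

```python
# Illustrative config overrides to enable FSDP via the modeling hook.
from d2go.runner import GeneralizedRCNNRunner

cfg = GeneralizedRCNNRunner.get_default_cfg()
# Append the hook so build_model wraps the model in FSDP.
cfg.MODEL.MODELING_HOOKS = list(cfg.MODEL.MODELING_HOOKS) + ["FSDPModelingHook"]
cfg.FSDP.ALGORITHM = "grad_optim"   # or "full" (ZeRO-3 style full sharding)
cfg.SOLVER.AMP.ENABLED = True       # AMP precision doubles as the FSDP mixed-precision dtype
cfg.SOLVER.AMP.PRECISION = "float16"
```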

**Note**
It's not possible to unwrap an FSDP model back to the original model, so `FSDPModelingHook.unapply()` can't be implemented.

Reviewed By: wat3rBro

Differential Revision: D41416917

fbshipit-source-id: f3fc72d574cc6ccbe0d238e48c575926ba5b4d06
parent 7b2ba6cb
@@ -59,7 +59,7 @@ class FSDPCheckpointer(QATCheckpointer):
         """
         # If no sharding, only the main process enters the saving codepath;
         # otherwise, all processes need to call state_dict() to enable state broadcasting among ranks
-        if not isinstance(self.model, FSDP) and not comm.is_main_process():
+        if not isinstance(self.model, FSDPWrapper) and not comm.is_main_process():
             return
         data = {}
...
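The comment above is the key constraint: under FSDP, producing a full state dict is a collective operation, so every rank must reach `state_dict()` even if only rank 0 writes the file. A standalone sketch of that pattern using the public PyTorch FSDP API (not code from this diff):

```python
# Sketch: gathering a full (unsharded) state dict from an FSDP-wrapped model.
# Every rank must call state_dict(); with rank0_only=True only rank 0 receives
# the full tensors and should write the checkpoint.
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)


def gather_full_state_dict(model: FSDP):
    save_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_cfg):
        state = model.state_dict()  # collective: all ranks participate
    return state if dist.get_rank() == 0 else None
```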
@@ -517,7 +517,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
             _get_model_with_abnormal_checker(model),
             data_loader,
             optimizer,
-            grad_scaler=get_grad_scaler(cfg.FSDP.ALGORITHM),
+            grad_scaler=get_grad_scaler(cfg),
             precision=parse_precision_from_string(
                 cfg.SOLVER.AMP.PRECISION, lightning=False
             ),
...
@@ -7,9 +7,11 @@ from typing import Callable, Iterable, Optional
 
 import detectron2.utils.comm as comm
 import torch
+import torch.nn as nn
 from d2go.config import CfgNode as CN
+from d2go.modeling import modeling_hook as mh
+from d2go.registry.builtin import MODELING_HOOK_REGISTRY
 from d2go.trainer.helper import parse_precision_from_string
-from detectron2.engine.defaults import create_ddp_model
 from detectron2.utils.registry import Registry
 from torch.cuda.amp import GradScaler
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
@@ -36,7 +38,7 @@ D2GO_FSDP_WRAP_POLICY_REGISTRY = Registry("D2GO_FSDP_WRAP_POLICY_REGISTRY")
 
 def add_fsdp_configs(_C: CN):
     _C.FSDP = CN()
-    _C.FSDP.ALGORITHM = ""  # 'grad_optim', 'full' or ''
+    _C.FSDP.ALGORITHM = ""  # 'grad_optim' or 'full'
 
     # Configs for fully sharded data parallel (fsdp)
     # Check out https://pytorch.org/docs/stable/fsdp.html
@@ -56,36 +58,23 @@ def add_fsdp_configs(_C: CN):
 
 class ShardingAlgorithm(str, Enum):
+    """
+    This enum specifies the sharding algorithm to be used by FullyShardedDataParallel (FSDP).
+    It matches the strings used in D2Go config with the enum class :class:`ShardingStrategy` used by Pytorch FSDP module:
+        "grad_optim" => ShardingAlgorithm.SHARD_GRAD_OP => ShardingStrategy.SHARD_GRAD_OP
+        "full" => ShardingAlgorithm.FULL_SHARD => ShardingStrategy.FULL_SHARD
+    """
+
     SHARD_GRAD_OP = "grad_optim"
     FULL_SHARD = "full"
-    NO_SHARD = ""
-
-    @classmethod
-    def is_valid(cls, key):
-        return key in {item.value for item in ShardingAlgorithm}
-
-    @classmethod
-    def use_sharding(cls, key):
-        return key in [cls.SHARD_GRAD_OP, cls.FULL_SHARD]
-
-
-def get_module_class_from_name(module, name):
-    """
-    Gets a class from a module by its name. Code borrowed from HuggingFace
-    Args:
-        module (`torch.nn.Module`): The module to get the class from.
-        name (`str`): The name of the class.
-    """
-    modules_children = list(module.children())
-    if module.__class__.__name__ == name:
-        return module.__class__
-    elif len(modules_children) == 0:
-        return
-    else:
-        for child_module in modules_children:
-            module_class = get_module_class_from_name(child_module, name)
-            if module_class is not None:
-                return module_class
+
+
+def is_fsdp_enabled(cfg):
+    return "FSDPModelingHook" in cfg.MODEL.MODELING_HOOKS
+
+
+def get_grad_scaler(cfg):
+    return ShardedGradScaler() if is_fsdp_enabled(cfg) else GradScaler()
 
 
 class FSDPWrapper(FSDP):
@@ -143,7 +132,7 @@ def build_fsdp(
         logger.info("Optimizer + Gradient + Horizontal Model Sharding (ZeRO-3) is used")
     else:
         raise ValueError(
-            f"Invalid sharding algorithm for building FSDP. Can be either {ShardingAlgorithm.SHARD_GRAD_OP} or {ShardingAlgorithm.FULL_SHARD}."
+            f"Invalid sharding algorithm for FSDP. Can be either {ShardingAlgorithm.SHARD_GRAD_OP} or {ShardingAlgorithm.FULL_SHARD}."
         )
     auto_wrap_policy = (
@@ -181,54 +170,61 @@ def build_fsdp(
     return FSDPWrapper(model, **wrapper_kwargs, **fsdp_kwargs)
 
 
-def create_ddp_model_with_sharding(cfg, model):
-    if not ShardingAlgorithm.is_valid(cfg.FSDP.ALGORITHM):
-        raise ValueError(
-            f"Invalid FSDP sharding algorithm. Can only be one of {[item.value for item in ShardingAlgorithm]}"
-        )
-    elif ShardingAlgorithm.use_sharding(cfg.FSDP.ALGORITHM):
-        # SOLVER.AMP.ENABLED and SOLVER.AMP.PRECISION controls mixed precision for all parameters, buffers and reduce in FSDP
-        precision_dtype = (
-            parse_precision_from_string(cfg.SOLVER.AMP.PRECISION, lightning=False)
-            if cfg.SOLVER.AMP.ENABLED
-            else None
-        )
-        wrapped_model = build_fsdp(
-            model,
-            sharding_algorithm=cfg.FSDP.ALGORITHM,
-            auto_wrap_policy_name=cfg.FSDP.AUTO_WRAP_POLICY,
-            auto_wrap_policy_kwargs={
-                "min_num_params": cfg.FSDP.AUTO_WRAP_MIN_PARAMS,
-                "layer_names": cfg.FSDP.AUTO_WRAP_LAYER_CLS,
-            },
-            use_cpu_offload=cfg.FSDP.CPU_OFFLOAD,
-            use_backward_prefetch=cfg.FSDP.BACKWARD_PREFETCH,
-            param_dtype=precision_dtype,
-            reduce_dtype=precision_dtype,
-            buffer_dtype=precision_dtype,
-            state_dict_cpu_offload=cfg.FSDP.STATE_DICT_CPU_OFFLOAD,
-            state_dict_rank0_only=cfg.FSDP.STATE_DICT_RANK0_ONLY,
-            device_id=torch.cuda.current_device(),
-        )
-    else:
-        wrapped_model = create_ddp_model(
-            model,
-            fp16_compression=cfg.MODEL.DDP_FP16_GRAD_COMPRESS,
-            device_ids=None if cfg.MODEL.DEVICE == "cpu" else [comm.get_local_rank()],
-            broadcast_buffers=False,
-            find_unused_parameters=cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS,
-        )
-    return wrapped_model
-
-
-def get_grad_scaler(fsdp_algorithm):
-    return (
-        ShardedGradScaler()
-        if ShardingAlgorithm.use_sharding(fsdp_algorithm)
-        else GradScaler()
-    )
+@MODELING_HOOK_REGISTRY.register()
+class FSDPModelingHook(mh.ModelingHook):
+    """Modeling hook that wraps model in FSDP based on config"""
+
+    def apply(self, model: nn.Module) -> FSDPWrapper:
+        # SOLVER.AMP.ENABLED and SOLVER.AMP.PRECISION controls mixed precision for all parameters, buffers and reduce in FSDP
+        precision_dtype = (
+            parse_precision_from_string(self.cfg.SOLVER.AMP.PRECISION, lightning=False)
+            if self.cfg.SOLVER.AMP.ENABLED
+            else None
+        )
+        wrapped_model = build_fsdp(
+            model,
+            sharding_algorithm=self.cfg.FSDP.ALGORITHM,
+            auto_wrap_policy_name=self.cfg.FSDP.AUTO_WRAP_POLICY,
+            auto_wrap_policy_kwargs={
+                "min_num_params": self.cfg.FSDP.AUTO_WRAP_MIN_PARAMS,
+                "layer_names": self.cfg.FSDP.AUTO_WRAP_LAYER_CLS,
+            },
+            use_cpu_offload=self.cfg.FSDP.CPU_OFFLOAD,
+            use_backward_prefetch=self.cfg.FSDP.BACKWARD_PREFETCH,
+            param_dtype=precision_dtype,
+            reduce_dtype=precision_dtype,
+            buffer_dtype=precision_dtype,
+            state_dict_cpu_offload=self.cfg.FSDP.STATE_DICT_CPU_OFFLOAD,
+            state_dict_rank0_only=self.cfg.FSDP.STATE_DICT_RANK0_ONLY,
+            device_id=torch.cuda.current_device(),
+        )
+        return wrapped_model
+
+    def unapply(self, model: FSDPWrapper) -> nn.Module:
+        raise NotImplementedError(
+            "FSDPModelingHook.unapply() not implemented: can't unwrap a FSDP module"
+        )
+
+
+def get_module_class_from_name(module, name):
+    """
+    Gets a class from a module by its name. Code borrowed from HuggingFace
+    Args:
+        module (`torch.nn.Module`): The module to get the class from.
+        name (`str`): The name of the class.
+    """
+    modules_children = list(module.children())
+    if module.__class__.__name__ == name:
+        return module.__class__
+    elif len(modules_children) == 0:
+        return
+    else:
+        for child_module in modules_children:
+            module_class = get_module_class_from_name(child_module, name)
+            if module_class is not None:
+                return module_class
 
 
 @D2GO_FSDP_WRAP_POLICY_REGISTRY.register()
 def always_wrap_policy(model, **kwargs) -> Optional[Callable]:
     """
...
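The trailing context above shows how auto-wrap policies are registered. As an aside (this example is hypothetical and not part of the diff), a custom policy can be added through the same registry and selected by name via `FSDP.AUTO_WRAP_POLICY`:

```python
# Hypothetical custom auto-wrap policy, following the signature of
# always_wrap_policy shown above: wrap any submodule above a parameter budget.
import functools
from typing import Callable, Optional

from d2go.trainer.fsdp import D2GO_FSDP_WRAP_POLICY_REGISTRY
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy


@D2GO_FSDP_WRAP_POLICY_REGISTRY.register()
def size_10m_wrap_policy(model, **kwargs) -> Optional[Callable]:
    return functools.partial(size_based_auto_wrap_policy, min_num_params=10_000_000)
```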
@@ -9,6 +9,7 @@ import logging
 import sys
 from typing import List, Type, Union
 
+import detectron2.utils.comm as comm
 from d2go.config import CfgNode
 from d2go.distributed import launch
 from d2go.runner import BaseRunner
@@ -22,12 +23,13 @@ from d2go.setup import (
     setup_root_logger,
 )
 from d2go.trainer.api import TestNetOutput, TrainNetOutput
-from d2go.trainer.fsdp import create_ddp_model_with_sharding
+from d2go.trainer.fsdp import is_fsdp_enabled
 from d2go.utils.misc import (
     dump_trained_model_configs,
     print_metrics_table,
     save_binary_outputs,
 )
+from detectron2.engine.defaults import create_ddp_model
 
 
 logger = logging.getLogger("d2go.tools.train_net")
@@ -62,14 +64,23 @@ def main(
             metrics=metrics,
         )
 
-    wrapped_model = create_ddp_model_with_sharding(cfg, model)
+    # Use DDP if FSDP is not enabled
+    # TODO (T142223289): rewrite ddp wrapping as modeling hook
+    if not is_fsdp_enabled(cfg):
+        model = create_ddp_model(
+            model,
+            fp16_compression=cfg.MODEL.DDP_FP16_GRAD_COMPRESS,
+            device_ids=None if cfg.MODEL.DEVICE == "cpu" else [comm.get_local_rank()],
+            broadcast_buffers=False,
+            find_unused_parameters=cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS,
+        )
 
-    trained_cfgs = runner.do_train(cfg, wrapped_model, resume=resume)
+    trained_cfgs = runner.do_train(cfg, model, resume=resume)
 
     final_eval = cfg.TEST.FINAL_EVAL
     if final_eval:
         # run evaluation after training in the same processes
-        metrics = runner.do_test(cfg, wrapped_model)
+        metrics = runner.do_test(cfg, model)
         print_metrics_table(metrics)
     else:
         metrics = {}
...
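With the hook applied inside `runner.build_model(cfg)`, the eval-only path from the summary follows directly. A rough sketch (runner method names as commonly used in d2go are assumed here; `cfg` is configured as in the earlier example and the weights path is illustrative):

```python
# Rough sketch of eval-only under FSDP: the model comes back from build_model
# already wrapped in FSDPWrapper, so evaluation uses the same sharding as training.
model = runner.build_model(cfg)
checkpointer = runner.build_checkpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=False)
metrics = runner.do_test(cfg, model)
```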