Commit 0389f4ee authored by Anthony Chen, committed by Facebook GitHub Bot

Enable activation checkpointing

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/573

Enable activation checkpointing from PyTorch Distributed in d2go.
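For context, enabling it is config-driven; below is a rough sketch based on the config keys and modeling hook added in this diff (the remaining model/dataset configuration is assumed):

# Hedged sketch: turn on activation checkpointing through the d2go config.
# The keys and hook name come from this diff; everything else about the
# model/dataset setup is assumed.
from d2go.runner.default_runner import Detectron2GoRunner

runner = Detectron2GoRunner()
cfg = runner.get_default_cfg()
cfg.MODEL.MODELING_HOOKS = ["ActivationCheckpointModelingHook"]
cfg.ACTIVATION_CHECKPOINT.REENTRANT = False  # non-reentrant checkpoint implementation
cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_POLICY = "layer_based_auto_wrap_policy"
cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_LAYER_CLS = ["Conv2d", "BatchNorm2d"]
# runner.build_model(cfg) then applies ActivationCheckpointModelingHook, wrapping the
# listed layer classes in CheckpointWrapper (see the unit test at the bottom of this diff).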

Reviewed By: rohan-varma

Differential Revision: D45681009

fbshipit-source-id: c03f27af61e0374b9e5991d82070edbe41edde6d
parent 3fce52cf
@@ -16,6 +16,7 @@ from d2go.modeling.model_freezing_utils import add_model_freezing_configs
 from d2go.modeling.subclass import add_subclass_configs
 from d2go.quantization.modeling import add_quantization_default_configs
 from d2go.registry.builtin import CONFIG_UPDATER_REGISTRY
+from d2go.trainer.activation_checkpointing import add_activation_checkpoint_configs
 from d2go.trainer.fsdp import add_fsdp_configs
 from d2go.utils.gpu_memory_profiler import add_memory_profiler_configs
 from d2go.utils.visualization import add_tensorboard_default_configs
@@ -87,6 +88,8 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
     add_distillation_configs(_C)
     # _C.FSDP
     add_fsdp_configs(_C)
+    # _C.ACTIVATION_CHECKPOINT
+    add_activation_checkpoint_configs(_C)
     # Set find_unused_parameters for DistributedDataParallel.
     _C.MODEL.DDP_FIND_UNUSED_PARAMETERS = False
#!/usr/bin/env python3
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

import logging
from functools import partial

import torch.nn as nn
from d2go.config import CfgNode as CN
from d2go.modeling import modeling_hook as mh
from d2go.registry.builtin import MODELING_HOOK_REGISTRY
from d2go.trainer.helper import D2GO_WRAP_POLICY_REGISTRY
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
)

logger = logging.getLogger(__name__)


def add_activation_checkpoint_configs(_C: CN):
    _C.ACTIVATION_CHECKPOINT = CN()
    _C.ACTIVATION_CHECKPOINT.REENTRANT = False
    # Find autowrap policy at D2GO_WRAP_POLICY_REGISTRY, or use '' to disable autowrap
    _C.ACTIVATION_CHECKPOINT.AUTO_WRAP_POLICY = "always_wrap_policy"
    # A list of layer cls names to wrap, case sensitive
    _C.ACTIVATION_CHECKPOINT.AUTO_WRAP_LAYER_CLS = []


@MODELING_HOOK_REGISTRY.register()
class ActivationCheckpointModelingHook(mh.ModelingHook):
    """Modeling hook that wraps model in activation checkpoint based on config"""

    def apply(self, model: nn.Module) -> nn.Module:
        logger.info("Activation Checkpointing is used")
        wrapper_fn = partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT
            if not self.cfg.ACTIVATION_CHECKPOINT.REENTRANT
            else CheckpointImpl.REENTRANT,
        )
        policy_name = self.cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_POLICY
        assert (
            policy_name != "size_based_auto_wrap_policy"
        ), "ActivationCheckpointing should always be wrapped at module boundary"
        policy_kwargs = {
            "layer_names": self.cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_LAYER_CLS,
        }
        # Look up the auto-wrap policy by name in the shared D2GO_WRAP_POLICY_REGISTRY
        auto_wrap_policy = (
            D2GO_WRAP_POLICY_REGISTRY.get(policy_name)(model, **policy_kwargs)
            if policy_name != ""
            else lambda _: True
        )
        apply_activation_checkpointing(
            model, checkpoint_wrapper_fn=wrapper_fn, auto_wrap_policy=auto_wrap_policy
        )
        return model

    def unapply(self, model: nn.Module) -> nn.Module:
        raise NotImplementedError(
            "ActivationCheckpointModelingHook.unapply() not implemented: can't unwrap an activation checkpoint module"
        )
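The hook can also be exercised outside a runner; here is a minimal sketch, assuming a toy torch.nn model (the hook, config helper, and CheckpointWrapper class come from this diff, the toy module and chosen layer names are illustrative):

# Hedged sketch: apply ActivationCheckpointModelingHook to a small module directly.
import torch.nn as nn
from d2go.config import CfgNode
from d2go.trainer.activation_checkpointing import (
    ActivationCheckpointModelingHook,
    add_activation_checkpoint_configs,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper

cfg = CfgNode()
add_activation_checkpoint_configs(cfg)
cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_POLICY = "layer_based_auto_wrap_policy"
cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_LAYER_CLS = ["Linear"]

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
ActivationCheckpointModelingHook(cfg).apply(model)
# Submodules whose class is listed in AUTO_WRAP_LAYER_CLS are wrapped in place;
# the ReLU in between is left untouched.
assert isinstance(model[0], CheckpointWrapper) and not isinstance(model[1], CheckpointWrapper)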
@@ -3,16 +3,14 @@
 import contextlib
 import logging
 from enum import Enum
-from functools import partial
-from typing import Callable, Generator, Iterable, Optional
+from typing import Generator, Optional

 import torch
 import torch.nn as nn
 from d2go.config import CfgNode as CN
 from d2go.modeling.modeling_hook import ModelingHook
 from d2go.registry.builtin import MODELING_HOOK_REGISTRY
-from d2go.trainer.helper import parse_precision_from_string
-from detectron2.utils.registry import Registry
+from d2go.trainer.helper import D2GO_WRAP_POLICY_REGISTRY, parse_precision_from_string
 from torch.ao.pruning import fqn_to_module
 from torch.cuda.amp import GradScaler
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
@@ -27,17 +25,10 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
     StateDictType,
 )
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-from torch.distributed.fsdp.wrap import (
-    always_wrap_policy as _always_wrap_policy,
-    size_based_auto_wrap_policy as _size_based_auto_wrap_policy,
-    transformer_auto_wrap_policy as _layer_based_auto_wrap_policy,
-)

 logger = logging.getLogger(__name__)

-D2GO_FSDP_WRAP_POLICY_REGISTRY = Registry("D2GO_FSDP_WRAP_POLICY_REGISTRY")


 def add_fsdp_configs(_C: CN):
     _C.FSDP = CN()
@@ -49,7 +40,7 @@ def add_fsdp_configs(_C: CN):
     # See docstring of CpuOffload and BackwardPrefetch in torch.distributed.fsdp.fully_sharded_data_parallel
     _C.FSDP.CPU_OFFLOAD = False
     _C.FSDP.BACKWARD_PREFETCH = True
-    # Find autowrap policy at D2GO_FSDP_WRAP_POLICY_REGISTRY, or use '' to disable autowrap
+    # Find autowrap policy at D2GO_WRAP_POLICY_REGISTRY, or use '' to disable autowrap
     _C.FSDP.AUTO_WRAP_POLICY = "never_wrap_policy"
     _C.FSDP.AUTO_WRAP_MIN_PARAMS = int(1e4)
     # A list of layer cls names to wrap, case sensitive
@@ -210,7 +201,7 @@ def build_fsdp(
     )
     auto_wrap_policy = (
-        D2GO_FSDP_WRAP_POLICY_REGISTRY.get(auto_wrap_policy_name)(
+        D2GO_WRAP_POLICY_REGISTRY.get(auto_wrap_policy_name)(
             model, **auto_wrap_policy_kwargs
         )
         if auto_wrap_policy_name != ""
@@ -321,83 +312,3 @@ class FSDPModelingHook(ModelingHook):
         raise NotImplementedError(
             "FSDPModelingHook.unapply() not implemented: can't unwrap a FSDP module"
         )
-
-
-def get_module_class_from_name(module, name):
-    """
-    Gets a class from a module by its name. Code borrowed from HuggingFace
-
-    Args:
-        module (`torch.nn.Module`): The module to get the class from.
-        name (`str`): The name of the class.
-    """
-    modules_children = list(module.children())
-    if module.__class__.__name__ == name:
-        return module.__class__
-    elif len(modules_children) == 0:
-        return
-    else:
-        for child_module in modules_children:
-            module_class = get_module_class_from_name(child_module, name)
-            if module_class is not None:
-                return module_class
-
-
-@D2GO_FSDP_WRAP_POLICY_REGISTRY.register()
-def never_wrap_policy(model, **kwargs) -> Optional[Callable]:
-    """
-    Don't wrap any child module, only wrap the root
-    """
-
-    def never_wrap(*args, **kwargs):
-        return False
-
-    return never_wrap
-
-
-@D2GO_FSDP_WRAP_POLICY_REGISTRY.register()
-def always_wrap_policy(model, **kwargs) -> Optional[Callable]:
-    """
-    Wrapper for always_wrap_policy() from torch.distributed.fsdp.wrap
-    """
-    return _always_wrap_policy
-
-
-@D2GO_FSDP_WRAP_POLICY_REGISTRY.register()
-def size_based_auto_wrap_policy(
-    model, min_num_params=1e4, **kwargs
-) -> Optional[Callable]:
-    """
-    Wrapper for size_based_auto_wrap_policy() from torch.distributed.fsdp.wrap
-    """
-    # Note: be careful when using auto wrap with shared parameters.
-    # Errors will be thrown if shared parameters reside in different FSDP units
-    return partial(
-        _size_based_auto_wrap_policy,
-        min_num_params=min_num_params,
-    )
-
-
-@D2GO_FSDP_WRAP_POLICY_REGISTRY.register()
-def layer_based_auto_wrap_policy(
-    model, layer_names: Iterable[str], **kwargs
-) -> Optional[Callable]:
-    """
-    Wrapper for transformer_auto_wrap_policy() from torch.distributed.fsdp.wrap
-    Args:
-        layer_names: a list of layer names
-    """
-    assert (
-        len(layer_names) > 0
-    ), "FSDP.AUTO_WRAP_LAYER_CLS should be a nonempty list of layer names contained in the model"
-    layer_cls = []
-    for name in layer_names:
-        closure = get_module_class_from_name(model, name)
-        if closure is None:
-            raise Exception(
-                f"Could not find the layer class {name} to wrap in the model."
-            )
-        layer_cls.append(closure)
-    return partial(
-        _layer_based_auto_wrap_policy,
-        transformer_layer_cls=layer_cls,
-    )
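Because both FSDP and activation checkpointing now resolve policy names from the shared D2GO_WRAP_POLICY_REGISTRY in d2go.trainer.helper, a project-specific policy only needs to be registered once. The following is a rough sketch under that assumption; the policy name and its leaf-only behavior are illustrative, and the callable signature mirrors the policies in torch.distributed.fsdp.wrap:

# Hedged sketch: register a custom wrap policy usable by both FSDP and
# activation checkpointing.
from typing import Callable, Optional

import torch.nn as nn
from d2go.trainer.helper import D2GO_WRAP_POLICY_REGISTRY


@D2GO_WRAP_POLICY_REGISTRY.register()
def leaf_only_policy(model, **kwargs) -> Optional[Callable]:
    """Wrap only leaf modules (modules with no children)."""

    def policy(module: nn.Module, recurse: bool, **_) -> bool:
        if recurse:
            return True  # keep traversing into children
        return len(list(module.children())) == 0

    return policy

It could then be selected by name, e.g. cfg.FSDP.AUTO_WRAP_POLICY = "leaf_only_policy" or cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_POLICY = "leaf_only_policy".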
-from typing import Union
+from functools import partial
+from typing import Any, Callable, Iterable, List, Optional, Union

 import torch
+from detectron2.utils.registry import Registry
+from torch.distributed.fsdp.wrap import (
+    always_wrap_policy as _always_wrap_policy,
+    size_based_auto_wrap_policy as _size_based_auto_wrap_policy,
+    transformer_auto_wrap_policy as _layer_based_auto_wrap_policy,
+)
+
+D2GO_WRAP_POLICY_REGISTRY = Registry("D2GO_WRAP_POLICY_REGISTRY")


 def parse_precision_from_string(
     precision: str, lightning=False
@@ -19,3 +30,94 @@ def parse_precision_from_string(
         return torch.bfloat16 if not lightning else "bf16"
     else:
         raise ValueError(f"Invalid precision dtype {precision}")
+
+
+def get_module_class_from_name(module, name):
+    """
+    Gets a class from a module by its name. Code borrowed from HuggingFace
+
+    Args:
+        module (`torch.nn.Module`): The module to get the class from.
+        name (`str`): The name of the class.
+    """
+    modules_children = list(module.children())
+    if module.__class__.__name__ == name:
+        return module.__class__
+    elif len(modules_children) == 0:
+        return
+    else:
+        for child_module in modules_children:
+            module_class = get_module_class_from_name(child_module, name)
+            if module_class is not None:
+                return module_class
+
+
+def get_layer_cls_from_names(
+    model: Any, layer_names: Iterable[str]
+) -> List[torch.nn.Module]:
+    """
+    Get a list of layers from a model that match a list of layer names.
+    """
+    layer_cls = []
+    for name in layer_names:
+        closure = get_module_class_from_name(model, name)
+        if closure is None:
+            raise Exception(
+                f"Could not find the layer class {name} to wrap in the model."
+            )
+        layer_cls.append(closure)
+    return layer_cls
+
+
+@D2GO_WRAP_POLICY_REGISTRY.register()
+def never_wrap_policy(model, **kwargs) -> Optional[Callable]:
+    """
+    Don't wrap any child module, only wrap the root
+    """
+
+    def never_wrap(*args, **kwargs):
+        return False
+
+    return never_wrap
+
+
+@D2GO_WRAP_POLICY_REGISTRY.register()
+def always_wrap_policy(model, **kwargs) -> Optional[Callable]:
+    """
+    Wrapper for always_wrap_policy() from torch.distributed.fsdp.wrap
+    """
+    return _always_wrap_policy
+
+
+@D2GO_WRAP_POLICY_REGISTRY.register()
+def size_based_auto_wrap_policy(
+    model, min_num_params=1e4, **kwargs
+) -> Optional[Callable]:
+    """
+    Wrapper for size_based_auto_wrap_policy() from torch.distributed.fsdp.wrap
+    """
+    # Note: be careful when using auto wrap with shared parameters.
+    # Errors will be thrown if shared parameters reside in different FSDP units
+    return partial(
+        _size_based_auto_wrap_policy,
+        min_num_params=min_num_params,
+    )
+
+
+@D2GO_WRAP_POLICY_REGISTRY.register()
+def layer_based_auto_wrap_policy(
+    model, layer_names: Iterable[str], **kwargs
+) -> Optional[Callable]:
+    """
+    Wrapper for transformer_auto_wrap_policy() from torch.distributed.fsdp.wrap
+    Args:
+        layer_names: a list of layer names
+    """
+    assert (
+        len(layer_names) > 0
+    ), "layer_names should be a nonempty list of layer names contained in the model"
+    layer_cls = get_layer_cls_from_names(model, layer_names)
+    return partial(
+        _layer_based_auto_wrap_policy,
+        transformer_layer_cls=layer_cls,
+    )
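For illustration, a small sketch of how the config strings in AUTO_WRAP_LAYER_CLS are turned into a wrap policy, assuming a toy torch.nn model (the helper names are the ones added above, the model is illustrative):

# Hedged sketch: resolve layer class names and build a layer-based wrap policy.
import torch.nn as nn
from d2go.trainer.helper import get_layer_cls_from_names, layer_based_auto_wrap_policy

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
# String names from the config are resolved to the actual classes found in the model.
assert get_layer_cls_from_names(model, ["Conv2d", "BatchNorm2d"]) == [nn.Conv2d, nn.BatchNorm2d]
# The returned policy is a partial of transformer_auto_wrap_policy over those classes.
policy = layer_based_auto_wrap_policy(model, layer_names=["Conv2d", "BatchNorm2d"])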
#!/usr/bin/env fbpython
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

import os
import unittest
from typing import Dict, List

import torch
from d2go.config import CfgNode
from d2go.modeling import modeling_hook as mh
from d2go.registry.builtin import META_ARCH_REGISTRY
from d2go.runner.default_runner import Detectron2GoRunner
from d2go.trainer.activation_checkpointing import (
    ActivationCheckpointModelingHook,
    add_activation_checkpoint_configs,
)
from d2go.utils.testing.data_loader_helper import create_local_dataset
from d2go.utils.testing.helper import tempdir
from detectron2.structures import ImageList
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointWrapper,
)


@META_ARCH_REGISTRY.register()
class MetaArchForTestAC(torch.nn.Module):
    def __init__(self, cfg: CfgNode) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 4, kernel_size=3, stride=1, padding=1)
        self.bn = torch.nn.BatchNorm2d(4)
        self.relu = torch.nn.ReLU(inplace=True)
        # Intentionally unused in forward(); the autowrap test checks it stays unwrapped.
        self.linear = torch.nn.Linear(4, 4)
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))

    @property
    def device(self) -> torch._C.device:
        return self.conv.weight.device

    def forward(self, inputs: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        images = [x["image"] for x in inputs]
        images = ImageList.from_tensors(images, 1)
        ret = self.conv(images.tensor)
        ret = self.bn(ret)
        ret = self.relu(ret)
        ret = self.avgpool(ret)
        return {"loss": ret.norm()}
def _get_cfg(runner, output_dir, dataset_name):
    cfg = runner.get_default_cfg()
    cfg.MODEL.DEVICE = "cpu"
    cfg.MODEL.META_ARCHITECTURE = "MetaArchForTestAC"

    cfg.DATASETS.TRAIN = (dataset_name,)
    cfg.DATASETS.TEST = (dataset_name,)

    cfg.INPUT.MIN_SIZE_TRAIN = (10,)
    cfg.INPUT.MIN_SIZE_TEST = (10,)

    cfg.SOLVER.MAX_ITER = 3
    cfg.SOLVER.STEPS = []
    cfg.SOLVER.WARMUP_ITERS = 1
    cfg.SOLVER.CHECKPOINT_PERIOD = 3
    cfg.SOLVER.IMS_PER_BATCH = 2

    cfg.MODEL_EMA.ENABLED = True
    cfg.OUTPUT_DIR = output_dir
    return cfg
class TestActivationCheckpointing(unittest.TestCase):
    def test_ac_config(self) -> None:
        cfg = CfgNode()
        add_activation_checkpoint_configs(cfg)
        self.assertTrue(isinstance(cfg.ACTIVATION_CHECKPOINT, CfgNode))
        self.assertEqual(cfg.ACTIVATION_CHECKPOINT.REENTRANT, False)
        self.assertEqual(
            cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_POLICY, "always_wrap_policy"
        )
        self.assertEqual(cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_LAYER_CLS, [])

    def test_ac_modeling_hook_apply(self) -> None:
        """Check that the hook is registered"""
        self.assertTrue("ActivationCheckpointModelingHook" in mh.MODELING_HOOK_REGISTRY)

        cfg = CfgNode()
        add_activation_checkpoint_configs(cfg)
        ac_hook = ActivationCheckpointModelingHook(cfg)
        model = MetaArchForTestAC(cfg)
        ac_hook.apply(model)
        children = list(model.children())
        self.assertTrue(len(children) == 5)
        for child in children:
            self.assertTrue(isinstance(child, CheckpointWrapper))

    def test_ac_modeling_hook_autowrap(self) -> None:
        cfg = CfgNode()
        add_activation_checkpoint_configs(cfg)
        cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_POLICY = "layer_based_auto_wrap_policy"
        cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_LAYER_CLS = ["Conv2d", "BatchNorm2d"]
        ac_hook = ActivationCheckpointModelingHook(cfg)
        model = MetaArchForTestAC(cfg)
        ac_hook.apply(model)
        self.assertTrue(isinstance(model.conv, CheckpointWrapper))
        self.assertTrue(isinstance(model.bn, CheckpointWrapper))
        self.assertFalse(isinstance(model.linear, CheckpointWrapper))
    @tempdir
    def test_ac_runner(self, tmp_dir) -> None:
        ds_name = create_local_dataset(tmp_dir, 5, 10, 10)
        runner = Detectron2GoRunner()
        cfg = _get_cfg(runner, tmp_dir, ds_name)
        cfg.MODEL.MODELING_HOOKS = ["ActivationCheckpointModelingHook"]
        cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_POLICY = "layer_based_auto_wrap_policy"
        cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_LAYER_CLS = ["Conv2d", "BatchNorm2d"]

        model = runner.build_model(cfg)
        runner.do_train(cfg, model, resume=False)
        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "model_0000002.pth")))