"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "aed7499a8d81de78bb1692d7a0745d3890618b0e"
Commit 02625ff8 authored by Anthony Chen, committed by Facebook GitHub Bot

Integrate PyTorch Fully Sharded Data Parallel (FSDP)

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/396

Integrate PyTorch FSDP, which supports two sharding modes: (1) gradient + optimizer state sharding; (2) full model sharding (parameters + gradients + optimizer state). The feature is enabled in the train_net.py code path.
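As a rough illustration (not part of this diff), the feature is configured through the new `FSDP.*` config keys registered by `add_fsdp_configs()` in d2go/trainer/fsdp.py. The snippet below is a hedged sketch of setting them on a bare CfgNode; everything other than the key names and their documented values is illustrative only:

    from d2go.config import CfgNode as CN
    from d2go.trainer.fsdp import add_fsdp_configs

    cfg = CN()
    add_fsdp_configs(cfg)  # registers cfg.FSDP.* with their defaults
    cfg.FSDP.ALGORITHM = "full"  # "grad_optim" = ZeRO-2 style, "full" = ZeRO-3 style, "" = no sharding
    cfg.FSDP.AUTO_WRAP_POLICY = "size_based_auto_wrap_policy"  # or "" to disable auto wrapping
    cfg.FSDP.AUTO_WRAP_MIN_PARAMS = int(1e4)

In the train_net.py code path, `create_ddp_model_with_sharding(cfg, model)` reads these keys and wraps the model in either FSDP or plain DDP.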

Sources
* Integration follows this tutorial: https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html

API changes
* Add new config keys to support the new feature; refer to mobile-vision/d2go/d2go/trainer/fsdp.py for the full list of config options.
* Add `FSDPCheckpointer`, a subclass of `QATCheckpointer`, to support the special loading/saving logic for FSDP models (usage sketched below).
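A hedged sketch of how the checkpointer is meant to be used (the runner wires this up via `build_checkpointer`); `wrapped_model` and `optimizer` are assumed to come from `create_ddp_model_with_sharding()` and the optimizer builder, and the file names are placeholders:

    import os

    from d2go.checkpoint import FSDPCheckpointer

    def save_and_reload(wrapped_model, optimizer, save_dir):
        checkpointer = FSDPCheckpointer(wrapped_model, save_dir=save_dir, optimizer=optimizer)
        # Every rank must call save(): FSDP gathers the full model/optimizer state collectively.
        checkpointer.save("model_final")
        # On load, the full optimizer state dict is re-sharded for the local rank.
        return checkpointer.load(os.path.join(save_dir, "model_final.pth"))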

Reviewed By: wat3rBro

Differential Revision: D39228316

fbshipit-source-id: 342ecb3bcbce748453c3fba2d6e1b7b7e478473c
parent 0316fed4
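# ---- new file: d2go/checkpoint/__init__.py (path inferred from the d2go.checkpoint imports in this diff) ----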
from .fsdp_checkpoint import FSDPCheckpointer
__all__ = ["FSDPCheckpointer"]
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import os
import detectron2.utils.comm as comm
import torch
from d2go.modeling.model_ema import EMAState
from d2go.quantization.modeling import QATCheckpointer
from d2go.trainer.fsdp import FSDPWrapper
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullyShardedDataParallel as FSDP,
)
# TODO: replace FSDPCheckpointer with central D2GoCheckpointer
class FSDPCheckpointer(QATCheckpointer):
"""
Extend the Checkpointer to support saving/loading FSDP models
"""
def load(self, path: str, checkpointables=None):
"""
Add support for loading sharded optimizer states in FSDP.
        .. note:: Loading optimizer states from regular (non-FSDP) checkpoints into FSDP models
            is currently not supported; in general, users should not resume regular training with FSDP.
"""
if isinstance(self.model, FSDPWrapper):
checkpointables_iter = (
self.checkpointables.keys()
if checkpointables is None
else checkpointables
)
checkpointables_filtered = [
name
for name in checkpointables_iter
if name not in ["optimizer", "ema_state"]
]
checkpoint = super().load(path, checkpointables=checkpointables_filtered)
if "optimizer" in checkpointables_iter:
self.logger.info("Loading optimizer from {} ...".format(path))
osd = checkpoint.pop("optimizer")
sharded_osd = FSDP.shard_full_optim_state_dict(osd, self.model)
self.checkpointables["optimizer"].load_state_dict(sharded_osd)
if "ema_state" in checkpointables_iter:
self.logger.info("Loading ema_state from {} ...".format(path))
ema_state = checkpoint.pop("ema_state")
scatter_ema_state_dict(ema_state, self.model)
# return all remaining checkpoints
return checkpoint
else:
return super().load(path, checkpointables=checkpointables)
def save(self, name: str, **kwargs) -> None:
"""
        Add support for saving sharded models and optimizers.
        The rest of the code is copied from the superclass implementation.
"""
# If no sharding, only the main process enters the saving codepath;
# otherwise, all processes need to call state_dict() to enable state broadcasting among ranks
if not isinstance(self.model, FSDP) and not comm.is_main_process():
return
data = {}
# FSDP: model.state_dict() needs to be called by all ranks before saving
data["model"] = self.model.state_dict()
for key, obj in self.checkpointables.items():
if key == "optimizer":
data[key] = gather_optimizer_state_dict(obj, self.model)
elif key == "ema_state":
data[key] = gather_ema_state_dict(obj, self.model)
else:
data[key] = obj.state_dict()
data.update(kwargs)
# Only the main process does checkpoint saving; code copied from vision/fair/fvcore/fvcore/common/checkpoint.py
if comm.is_main_process():
basename = "{}.pth".format(name)
save_file = os.path.join(self.save_dir, basename)
assert os.path.basename(save_file) == basename, basename
self.logger.info("Saving checkpoint to {}".format(save_file))
with self.path_manager.open(save_file, "wb") as f:
# pyre-fixme[6]: For 2nd param expected `Union[PathLike[typing.Any],
# IO[bytes], str, BinaryIO]` but got `Union[IO[bytes], IO[str]]`.
torch.save(data, f)
self.tag_last_checkpoint(basename)
def gather_optimizer_state_dict(optimizer, model=None):
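    """
    Return a full (unsharded) optimizer state dict; for FSDP-wrapped models, this
    gathers the sharded optimizer state from all ranks.
    """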
# FSDP: full_optim_state_dict() needs to be called by all ranks
if isinstance(model, FSDPWrapper):
return FSDP.full_optim_state_dict(model, optimizer, rank0_only=model.rank0_only)
return optimizer.state_dict()
def gather_ema_state_dict(ema_state, model):
"""
Get EMA state dict.
For FSDP, gather local sharded EMA states from all FSDP processes and aggregate them into a FULL GLOBAL state dict
"""
if isinstance(model, FSDPWrapper):
# Apply local ema states to the model and unshard them
with ema_state.apply_and_restore(model):
with FSDP.summon_full_params(
model,
writeback=False,
offload_to_cpu=model.offload_to_cpu,
rank0_only=model.rank0_only,
):
state = EMAState.FromModel(model)
return state.state
else:
return ema_state.state_dict()
def scatter_ema_state_dict(ema_state_dict, model):
"""
Load an EMA state dict to the model.
EMA state represents a FULL GLOBAL state dict and needs to be properly sharded for each FSDP process to store locally
"""
if isinstance(model, FSDPWrapper):
# Store the current model state.
old_local_state = EMAState.FromModel(model)
# Apply ema_state as a FULL state dict to the model so it can be properly sharded
# Currently only [offload_to_cpu=False, rank0_only=False] is supported
with FSDP.summon_full_params(
model,
writeback=True,
offload_to_cpu=False,
rank0_only=False,
):
ema_state = EMAState()
ema_state.load_state_dict(ema_state_dict)
ema_state.apply_to(model)
# Load ema_state from model
model.ema_state.save_from(model)
# Restore the old model state
old_local_state.apply_to(model)
else:
model.ema_state.load_state_dict(ema_state_dict)
...
@@ -246,9 +246,10 @@ def sgd(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
     Build an optimizer from config.
     """
     params = get_optimizer_param_groups(model, cfg)
     return maybe_add_gradient_clipping(cfg, torch.optim.SGD)(
-        params,
-        cfg.SOLVER.BASE_LR,
+        params=params,
+        lr=cfg.SOLVER.BASE_LR,
         momentum=cfg.SOLVER.MOMENTUM,
         nesterov=cfg.SOLVER.NESTEROV,
         foreach=True,
@@ -262,10 +263,9 @@ def adam(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
     """
     params = get_optimizer_param_groups(model, cfg)
-    optim = maybe_add_gradient_clipping(cfg, torch.optim.Adam)(
-        params, cfg.SOLVER.BASE_LR, betas=cfg.SOLVER.BETAS
-    )
-    return optim
+    return maybe_add_gradient_clipping(cfg, torch.optim.Adam)(
+        params=params, lr=cfg.SOLVER.BASE_LR, betas=cfg.SOLVER.BETAS
+    )
 
 @D2GO_OPTIM_MAPPER_REGISTRY.register()
@@ -275,10 +275,9 @@ def adamw(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
     """
     params = get_optimizer_param_groups(model, cfg)
-    optim = maybe_add_gradient_clipping(cfg, torch.optim.AdamW)(
-        params, cfg.SOLVER.BASE_LR, betas=cfg.SOLVER.BETAS
-    )
-    return optim
+    return maybe_add_gradient_clipping(cfg, torch.optim.AdamW)(
+        params=params, lr=cfg.SOLVER.BASE_LR, betas=cfg.SOLVER.BETAS
+    )
 
 @D2GO_OPTIM_MAPPER_REGISTRY.register()
@@ -291,8 +290,8 @@ def sgd_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
     """
     params = get_optimizer_param_groups(model, cfg)
     return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.SGD)(
-        params,
-        cfg.SOLVER.BASE_LR,
+        params=params,
+        lr=cfg.SOLVER.BASE_LR,
         momentum=cfg.SOLVER.MOMENTUM,
         nesterov=cfg.SOLVER.NESTEROV,
     )
@@ -308,7 +307,7 @@ def adamw_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
     """
     params = get_optimizer_param_groups(model, cfg)
     return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.AdamW)(
-        params, cfg.SOLVER.BASE_LR
+        params=params, lr=cfg.SOLVER.BASE_LR
     )
...
@@ -15,6 +15,7 @@ from d2go.modeling.meta_arch.fcos import add_fcos_configs
 from d2go.modeling.model_freezing_utils import add_model_freezing_configs
 from d2go.modeling.subclass import add_subclass_configs
 from d2go.quantization.modeling import add_quantization_default_configs
+from d2go.trainer.fsdp import add_fsdp_configs
 from d2go.utils.visualization import add_tensorboard_default_configs
 from detectron2.config import get_cfg as get_d2_cfg
 from mobile_cv.common.misc.oss_utils import fb_overwritable
@@ -58,6 +59,8 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
     add_fcos_configs(_C)
     # _C.DISTILLATION
     add_distillation_configs(_C)
+    # _C.FSDP
+    add_fsdp_configs(_C)
     # Set find_unused_parameters for DistributedDataParallel.
     _C.MODEL.DDP_FIND_UNUSED_PARAMETERS = False
...
@@ -11,6 +11,7 @@ from typing import List, Optional, Type, Union
 import d2go.utils.abnormal_checker as abnormal_checker
 import detectron2.utils.comm as comm
 import torch
+from d2go.checkpoint import FSDPCheckpointer
 from d2go.config import CfgNode, CONFIG_SCALING_METHOD_REGISTRY, temp_defrost
 from d2go.config.utils import get_cfg_diff_table
 from d2go.data.build import build_d2go_train_loader
@@ -28,13 +29,14 @@ from d2go.modeling import kmeans_anchors, model_ema
 from d2go.modeling.api import build_d2go_model
 from d2go.modeling.model_freezing_utils import freeze_matched_bn, set_requires_grad
 from d2go.optimizer import build_optimizer_mapper
-from d2go.quantization.modeling import QATCheckpointer, QATHook, setup_qat_model
+from d2go.quantization.modeling import QATHook, setup_qat_model
 from d2go.runner.config_defaults import (
     get_base_runner_default_cfg,
     get_detectron2go_runner_default_cfg,
     get_generalized_rcnn_runner_default_cfg,
 )
 from d2go.runner.training_hooks import update_hooks_from_registry
+from d2go.trainer.fsdp import get_grad_scaler
 from d2go.trainer.helper import parse_precision_from_string
 from d2go.utils.flop_calculator import attach_profilers
 from d2go.utils.helper import D2Trainer, TensorboardXWriter
@@ -269,7 +271,7 @@ class Detectron2GoRunner(BaseRunner):
     def build_checkpointer(self, cfg, model, save_dir, **kwargs):
         kwargs.update(model_ema.may_get_ema_checkpointer(cfg, model))
-        checkpointer = QATCheckpointer(model, save_dir=save_dir, **kwargs)
+        checkpointer = FSDPCheckpointer(model, save_dir=save_dir, **kwargs)
         return checkpointer
 
     def build_optimizer(self, cfg, model):
@@ -470,6 +472,7 @@ class Detectron2GoRunner(BaseRunner):
             _get_model_with_abnormal_checker(model),
             data_loader,
             optimizer,
+            grad_scaler=get_grad_scaler(cfg.FSDP.ALGORITHM),
             precision=parse_precision_from_string(
                 cfg.SOLVER.AMP.PRECISION, lightning=False
             ),
@@ -608,7 +611,10 @@ class Detectron2GoRunner(BaseRunner):
             scheduler.step()
             # Note: when precise BN is enabled, some checkpoints will have more precise
             # statistics than others, if they are saved immediately after eval.
-            if comm.is_main_process():
+            # Note: FSDP requires all ranks to execute saving/loading logic
+            if comm.is_main_process() or isinstance(
+                periodic_checkpointer.checkpointer, FSDPCheckpointer
+            ):
                 periodic_checkpointer.step(trainer.iter)
 
         return hooks.CallbackHook(after_step=after_step_callback)
...
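# ---- new file: d2go/trainer/fsdp.py ----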
#!/usr/bin/env python3
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import logging
from enum import Enum
from functools import partial
from typing import Callable, Iterable, Optional
import detectron2.utils.comm as comm
import torch
from d2go.config import CfgNode as CN
from d2go.trainer.helper import parse_precision_from_string
from detectron2.engine.defaults import create_ddp_model
from detectron2.utils.registry import Registry
from torch.cuda.amp import GradScaler
from torch.distributed.fsdp.fully_sharded_data_parallel import (
BackwardPrefetch,
CPUOffload,
FullStateDictConfig,
FullyShardedDataParallel as FSDP,
MixedPrecision,
ShardingStrategy,
StateDictType,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import (
always_wrap_policy as _always_wrap_policy,
size_based_auto_wrap_policy as _size_based_auto_wrap_policy,
transformer_auto_wrap_policy as _layer_based_auto_wrap_policy,
)
logger = logging.getLogger(__name__)
D2GO_FSDP_WRAP_POLICY_REGISTRY = Registry("D2GO_FSDP_WRAP_POLICY_REGISTRY")
def add_fsdp_configs(_C: CN):
_C.FSDP = CN()
_C.FSDP.ALGORITHM = "" # 'grad_optim', 'full' or ''
# Configs for fully sharded data parallel (fsdp)
# Check out https://pytorch.org/docs/stable/fsdp.html
# and docstring of torch.distributed.fsdp.fully_sharded_data_parallel
# See docstring of CpuOffload and BackwardPrefetch in torch.distributed.fsdp.fully_sharded_data_parallel
_C.FSDP.CPU_OFFLOAD = False
_C.FSDP.BACKWARD_PREFETCH = True
# Find autowrap policy at D2GO_FSDP_WRAP_POLICY_REGISTRY, or use '' to disable autowrap
_C.FSDP.AUTO_WRAP_POLICY = ""
_C.FSDP.AUTO_WRAP_MIN_PARAMS = int(1e4)
# A list of layer cls names to wrap, case sensitive
_C.FSDP.AUTO_WRAP_LAYER_CLS = []
# Whether to offload state dict to cpu
_C.FSDP.STATE_DICT_CPU_OFFLOAD = True
# Whether to materialize state dict on rank 0
_C.FSDP.STATE_DICT_RANK0_ONLY = True
class ShardingAlgorithm(str, Enum):
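    """
    Sharding algorithm selected via cfg.FSDP.ALGORITHM: "grad_optim" shards gradients
    and optimizer state (ZeRO-2), "full" additionally shards parameters (ZeRO-3), and
    "" disables sharding (plain DDP is used instead).
    """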
SHARD_GRAD_OP = "grad_optim"
FULL_SHARD = "full"
NO_SHARD = ""
@classmethod
def is_valid(cls, key):
return key in {item.value for item in ShardingAlgorithm}
@classmethod
def use_sharding(cls, key):
return key in [cls.SHARD_GRAD_OP, cls.FULL_SHARD]
def get_module_class_from_name(module, name):
"""
Gets a class from a module by its name. Code borrowed from HuggingFace
Args:
module (`torch.nn.Module`): The module to get the class from.
name (`str`): The name of the class.
"""
modules_children = list(module.children())
if module.__class__.__name__ == name:
return module.__class__
elif len(modules_children) == 0:
return
else:
for child_module in modules_children:
module_class = get_module_class_from_name(child_module, name)
if module_class is not None:
return module_class
class FSDPWrapper(FSDP):
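    """
    Thin wrapper around FullyShardedDataParallel that makes state_dict() and
    load_state_dict() always operate on FULL (unsharded) state dicts, with
    configurable CPU offload and rank0-only gathering when saving.
    """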
def __init__(
self,
model,
state_dict_cpu_offload=True,
state_dict_rank0_only=True,
**fsdp_kwargs,
):
self.offload_to_cpu = state_dict_cpu_offload
self.rank0_only = state_dict_rank0_only
super().__init__(model, **fsdp_kwargs)
def state_dict(self, *args, **kwargs):
# TODO: support local state dict
# NOTE: model.state_dict() needs to be called by all ranks because synchronization primitives are used
save_policy = FullStateDictConfig(
offload_to_cpu=self.offload_to_cpu, rank0_only=self.rank0_only
)
with FSDP.state_dict_type(self, StateDictType.FULL_STATE_DICT, save_policy):
return super().state_dict(*args, **kwargs)
def load_state_dict(
self,
state_dict,
*args,
**kwargs,
):
with FSDP.state_dict_type(self, StateDictType.FULL_STATE_DICT):
return super().load_state_dict(state_dict, *args, **kwargs)
def build_fsdp(
model,
*,
sharding_algorithm: str = ShardingAlgorithm.FULL_SHARD,
auto_wrap_policy_name: str = "",
auto_wrap_policy_kwargs: Optional[dict] = None,
use_cpu_offload: bool = False,
use_backward_prefetch: bool = True,
param_dtype: Optional[torch.dtype] = None,
reduce_dtype: Optional[torch.dtype] = None,
buffer_dtype: Optional[torch.dtype] = None,
state_dict_cpu_offload: bool = True,
state_dict_rank0_only: bool = True,
device_id: Optional[int] = None,
):
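    """
    Wrap `model` in FSDPWrapper with the requested sharding strategy ("grad_optim"
    or "full"), auto-wrap policy, CPU offload, backward prefetch, mixed-precision
    dtypes, and the state-dict options consumed by FSDPCheckpointer.
    """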
if sharding_algorithm == ShardingAlgorithm.SHARD_GRAD_OP:
sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
logger.info("Optimizer + Gradient State Sharding (ZeRO-2) is used")
elif sharding_algorithm == ShardingAlgorithm.FULL_SHARD:
sharding_strategy = ShardingStrategy.FULL_SHARD
logger.info("Optimizer + Gradient + Horizontal Model Sharding (ZeRO-3) is used")
else:
raise ValueError(
f"Invalid sharding algorithm for building FSDP. Can be either {ShardingAlgorithm.SHARD_GRAD_OP} or {ShardingAlgorithm.FULL_SHARD}."
)
auto_wrap_policy = (
D2GO_FSDP_WRAP_POLICY_REGISTRY.get(auto_wrap_policy_name)(
model, **auto_wrap_policy_kwargs
)
if auto_wrap_policy_name != ""
else None
)
cpu_offload = CPUOffload(offload_params=use_cpu_offload)
mixed_precision = MixedPrecision(
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
buffer_dtype=buffer_dtype,
keep_low_precision_grads=False,
)
backward_prefetch = (
BackwardPrefetch.BACKWARD_PRE
if use_backward_prefetch
else BackwardPrefetch.BACKWARD_POST
)
fsdp_kwargs = {
"sharding_strategy": sharding_strategy,
"cpu_offload": cpu_offload,
"mixed_precision": mixed_precision,
"auto_wrap_policy": auto_wrap_policy,
"backward_prefetch": backward_prefetch,
"device_id": torch.cuda.current_device() if not device_id else device_id,
}
wrapper_kwargs = {
"state_dict_cpu_offload": state_dict_cpu_offload,
"state_dict_rank0_only": state_dict_rank0_only,
}
return FSDPWrapper(model, **wrapper_kwargs, **fsdp_kwargs)
def create_ddp_model_with_sharding(cfg, model):
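    """
    Wrap the model with FSDP when cfg.FSDP.ALGORITHM requests sharding; otherwise
    fall back to detectron2's create_ddp_model (plain DDP).
    """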
if not ShardingAlgorithm.is_valid(cfg.FSDP.ALGORITHM):
raise ValueError(
f"Invalid FSDP sharding algorithm. Can only be one of {[item.value for item in ShardingAlgorithm]}"
)
elif ShardingAlgorithm.use_sharding(cfg.FSDP.ALGORITHM):
# SOLVER.AMP.ENABLED and SOLVER.AMP.PRECISION controls mixed precision for all parameters, buffers and reduce in FSDP
precision_dtype = (
parse_precision_from_string(cfg.SOLVER.AMP.PRECISION, lightning=False)
if cfg.SOLVER.AMP.ENABLED
else None
)
wrapped_model = build_fsdp(
model,
sharding_algorithm=cfg.FSDP.ALGORITHM,
auto_wrap_policy_name=cfg.FSDP.AUTO_WRAP_POLICY,
auto_wrap_policy_kwargs={
"min_num_params": cfg.FSDP.AUTO_WRAP_MIN_PARAMS,
"layer_names": cfg.FSDP.AUTO_WRAP_LAYER_CLS,
},
use_cpu_offload=cfg.FSDP.CPU_OFFLOAD,
use_backward_prefetch=cfg.FSDP.BACKWARD_PREFETCH,
param_dtype=precision_dtype,
reduce_dtype=precision_dtype,
buffer_dtype=precision_dtype,
state_dict_cpu_offload=cfg.FSDP.STATE_DICT_CPU_OFFLOAD,
state_dict_rank0_only=cfg.FSDP.STATE_DICT_RANK0_ONLY,
device_id=torch.cuda.current_device(),
)
else:
wrapped_model = create_ddp_model(
model,
fp16_compression=cfg.MODEL.DDP_FP16_GRAD_COMPRESS,
device_ids=None if cfg.MODEL.DEVICE == "cpu" else [comm.get_local_rank()],
broadcast_buffers=False,
find_unused_parameters=cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS,
)
return wrapped_model
def get_grad_scaler(fsdp_algorithm):
return (
ShardedGradScaler()
if ShardingAlgorithm.use_sharding(fsdp_algorithm)
else GradScaler()
)
@D2GO_FSDP_WRAP_POLICY_REGISTRY.register()
def always_wrap_policy(model, **kwargs) -> Optional[Callable]:
"""
Wrapper for always_wrap_policy() from torch.distributed.fsdp.wrap
"""
return _always_wrap_policy
@D2GO_FSDP_WRAP_POLICY_REGISTRY.register()
def size_based_auto_wrap_policy(
model, min_num_params=1e4, **kwargs
) -> Optional[Callable]:
"""
Wrapper for size_based_auto_wrap_policy() from torch.distributed.fsdp.wrap
"""
# Note: be careful when using auto wrap with shared parameters.
# Errors will be thrown if shared parameters reside in different FSDP units
return partial(
_size_based_auto_wrap_policy,
min_num_params=min_num_params,
)
@D2GO_FSDP_WRAP_POLICY_REGISTRY.register()
def layer_based_auto_wrap_policy(
model, layer_names: Iterable[str], **kwargs
) -> Optional[Callable]:
"""
Wrapper for transformer_auto_wrap_policy() from torch.distributed.fsdp.wrap
Args:
layer_names: a list of layer names
"""
assert (
len(layer_names) > 0
), "FSDP.AUTO_WRAP_LAYER_CLS should be a nonempty list of layer names contained in the model"
layer_cls = []
for name in layer_names:
closure = get_module_class_from_name(model, name)
if closure is None:
raise Exception(
f"Could not find the layer class {name} to wrap in the model."
)
layer_cls.append(closure)
return partial(
_layer_based_auto_wrap_policy,
transformer_layer_cls=layer_cls,
)
@@ -121,6 +121,14 @@ def add_flop_printing_hook(
 @PROFILER_REGISTRY.register()
 def default_flop_counter(model, cfg):
+    from d2go.trainer.fsdp import FSDP
+
+    # TODO: deepcopy() is not supported for FSDP yet (https://github.com/pytorch/pytorch/issues/82070),
+    # so the flop counter is disabled for now
+    if isinstance(model, FSDP):
+        logger.warning(
+            "Default flop counter is disabled because it's not supported for FSDP yet."
+        )
+        return
     return add_flop_printing_hook(model, cfg.OUTPUT_DIR)
...
@@ -23,12 +23,12 @@ from d2go.setup import (
     setup_root_logger,
 )
 from d2go.trainer.api import TrainNetOutput
+from d2go.trainer.fsdp import create_ddp_model_with_sharding
 from d2go.utils.misc import (
     dump_trained_model_configs,
     print_metrics_table,
     save_binary_outputs,
 )
-from detectron2.engine.defaults import create_ddp_model
 
 logger = logging.getLogger("d2go.tools.train_net")
@@ -64,20 +64,14 @@ def main(
         metrics=metrics,
     )
 
-    model = create_ddp_model(
-        model,
-        fp16_compression=cfg.MODEL.DDP_FP16_GRAD_COMPRESS,
-        device_ids=None if cfg.MODEL.DEVICE == "cpu" else [comm.get_local_rank()],
-        broadcast_buffers=False,
-        find_unused_parameters=cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS,
-    )
+    wrapped_model = create_ddp_model_with_sharding(cfg, model)
 
-    trained_cfgs = runner.do_train(cfg, model, resume=resume)
+    trained_cfgs = runner.do_train(cfg, wrapped_model, resume=resume)
 
     final_eval = cfg.TEST.FINAL_EVAL
     if final_eval:
         # run evaluation after training in the same processes
-        metrics = runner.do_test(cfg, model)
+        metrics = runner.do_test(cfg, wrapped_model)
         print_metrics_table(metrics)
     else:
         metrics = {}
...