Commit f23248c0 authored by facebook-github-bot's avatar facebook-github-bot
Browse files

Initial commit

fbshipit-source-id: f4a8ba78691d8cf46e003ef0bd2e95f170932778
parents
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import d2go.data.transforms.crop as tfm_crop
import d2go.data.transforms.tensor as tfm_tensor
import detectron2.data.transforms as transforms
import torch
from detectron2.data.transforms.augmentation import AugmentationList
from torch import nn
class ImagePooler(nn.Module):
    """Get a subset of image

    Returns the transforms that could be used to inverse the image/boxes/keypoints
    as well.
    Only available for inference. The code is not traceable/scriptable.
    """

    def __init__(
        self,
        resize_type="resize_shortest",
        resize_short=None,
        resize_max=None,
        box_scale_factor=1.0,
    ):
        """
        Args:
            resize_type: one of "resize_shortest", "resize", "None"/None;
                selects the optional resize applied after cropping.
            resize_short: shortest-edge target ("resize_shortest") or the
                target size ("resize").
            resize_max: longest-edge cap, only used by "resize_shortest".
            box_scale_factor: factor applied to the crop box before cropping.
        """
        super().__init__()
        assert resize_type in ["resize_shortest", "resize", "None", None]
        resizer = None
        if resize_type == "resize_shortest":
            resizer = transforms.ResizeShortestEdge(resize_short, resize_max)
        elif resize_type == "resize":
            resizer = transforms.Resize(resize_short)
        # pipeline: tensor -> ndarray -> crop (scaled box) -> [resize] -> tensor
        self.aug = [
            tfm_tensor.Tensor2Array(),
            tfm_crop.CropBoxAug(box_scale_factor=box_scale_factor),
            *([resizer] if resizer else []),
            tfm_tensor.Array2Tensor(),
        ]

    def forward(self, x: torch.Tensor, box: torch.Tensor):
        """box: 1 x 4 tensor in XYXY format"""
        assert not self.training
        assert isinstance(x, torch.Tensor)
        assert isinstance(box, torch.Tensor)
        # box: 1 x 4 in xyxy format
        # Augmentations run on CPU copies; results are moved back to the
        # input devices before returning.
        inputs = tfm_tensor.AugInput(image=x.cpu(), boxes=box.cpu())
        # FIX: renamed the local from `transforms` to `applied_tfms` — the
        # original name shadowed the file-level `detectron2.data.transforms`
        # module import inside this method.
        applied_tfms = AugmentationList(self.aug)(inputs)
        return (
            inputs.image.to(x.device),
            torch.Tensor(inputs.boxes).to(box.device),
            applied_tfms,
        )
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import numpy as np
import torch
from typing import List
import detectron2.utils.comm as comm
from detectron2.engine import hooks
from detectron2.layers import ShapeSpec
from detectron2.modeling import GeneralizedRCNN
from detectron2.modeling.anchor_generator import (
ANCHOR_GENERATOR_REGISTRY,
BufferList,
DefaultAnchorGenerator,
)
from detectron2.modeling.proposal_generator.rpn import RPN
from detectron2.structures.boxes import Boxes
from d2go.config import temp_defrost, CfgNode as CN
logger = logging.getLogger(__name__)
def add_kmeans_anchors_cfg(_C):
    """Register MODEL.KMEANS_ANCHORS (and ANCHOR_GENERATOR.OFFSET) on ``_C``."""
    kmeans_node = CN()
    kmeans_node.KMEANS_ANCHORS_ON = False
    kmeans_node.NUM_CLUSTERS = 0
    kmeans_node.NUM_TRAINING_IMG = 0
    kmeans_node.DATASETS = ()
    kmeans_node.RNG_SEED = 3
    _C.MODEL.KMEANS_ANCHORS = kmeans_node
    # offset consumed by the (kmeans) anchor generator
    _C.MODEL.ANCHOR_GENERATOR.OFFSET = 0.0
    return _C
def compute_kmeans_anchors_hook(runner, cfg):
    """
    This function will create a before_train hook, it will:
    1: create a train loader using provided KMEANS_ANCHORS.DATASETS.
    2: collecting statistics of boxes using outputs from train loader, use up
    to KMEANS_ANCHORS.NUM_TRAINING_IMG images.
    3: compute K-means using KMEANS_ANCHORS.NUM_CLUSTERS clusters
    4: update the buffers in anchor_generator.
    """

    def _before_train(trainer):
        if not cfg.MODEL.KMEANS_ANCHORS.KMEANS_ANCHORS_ON:
            return

        # Build a loader reading from the kmeans datasets rather than the
        # regular training datasets.
        kmeans_cfg = cfg.clone()
        with temp_defrost(kmeans_cfg):
            kmeans_cfg.DATASETS.TRAIN = cfg.MODEL.KMEANS_ANCHORS.DATASETS
        loader = runner.build_detection_train_loader(kmeans_cfg)
        cluster_anchors = compute_kmeans_anchors(cfg, loader).tolist()

        # Only the RPN of a GeneralizedRCNN with a KMeansAnchorGenerator
        # knows how to consume the computed anchors.
        assert isinstance(trainer.model, GeneralizedRCNN)
        assert isinstance(trainer.model.proposal_generator, RPN)
        generator = trainer.model.proposal_generator.anchor_generator
        assert isinstance(generator, KMeansAnchorGenerator)
        generator.update_cell_anchors(cluster_anchors)

    return hooks.CallbackHook(before_train=_before_train)
@ANCHOR_GENERATOR_REGISTRY.register()
class KMeansAnchorGenerator(DefaultAnchorGenerator):
    """ Generate anchors using pre-computed KMEANS_ANCHORS.COMPUTED_ANCHORS """

    def __init__(self, cfg, input_shape: List[ShapeSpec]):
        # Intentionally skip DefaultAnchorGenerator.__init__: cell anchors
        # here come from k-means clustering, not from sizes/aspect ratios.
        torch.nn.Module.__init__(self)
        self.strides = [x.stride for x in input_shape]
        self.offset = cfg.MODEL.ANCHOR_GENERATOR.OFFSET
        assert 0.0 <= self.offset < 1.0, self.offset

        # kmeans anchors
        num_features = len(cfg.MODEL.RPN.IN_FEATURES)
        assert num_features == 1, "Doesn't support multiple feature map"

        # NOTE: KMEANS anchors are only computed at training time, when initialized,
        # set anchors to correct shape but invalid value as place holder.
        computed_anchors = [[float("Inf")] * 4] * cfg.MODEL.KMEANS_ANCHORS.NUM_CLUSTERS
        cell_anchors = [torch.Tensor(computed_anchors)]
        self.cell_anchors = BufferList(cell_anchors)

    def update_cell_anchors(self, computed_anchors):
        # Replace the Inf placeholder buffer with the computed cluster
        # centers (a NUM_CLUSTERS x 4 nested list of XYXY anchors).
        assert len(self.cell_anchors) == 1
        for buf in self.cell_anchors.buffers():
            assert len(buf) == len(computed_anchors)
            buf.data = torch.Tensor(computed_anchors).to(buf.device)
        logger.info("Updated cell anchors")

    def forward(self, *args, **kwargs):
        # Guard against using the Inf placeholders set in __init__ before
        # update_cell_anchors / weight loading has happened.
        for base_anchors in self.cell_anchors:
            assert torch.isfinite(base_anchors).all(), (
                "The anchors are not initialized yet, please providing COMPUTED_ANCHORS"
                " when creating the model and/or loading the valid weights."
            )
        return super().forward(*args, **kwargs)
def collect_boxes_size_stats(data_loader, max_num_imgs, _legacy_plus_one=False):
    """Collect (width, height) of every gt box from up to `max_num_imgs`
    images read from `data_loader`.

    Args:
        data_loader: iterable yielding lists of input dicts; each dict has an
            "instances" field carrying `gt_boxes` in XYXY format.
        max_num_imgs: maximum number of images to consume; must be >= 0.
        _legacy_plus_one: when True, subtract the image scale from each size
            to reproduce old (incorrect) results; only used to match old tests.

    Returns:
        np.ndarray of shape (num_boxes, 2): one (w, h) row per box.
    """
    logger.info(
        "Collecting size of boxes, loading up to {} images from data loader ..."
        .format(max_num_imgs)
    )
    # data_loader might be infinite length, thus can't loop all images, using
    # max_num_imgs == 0 stands for 0 images instead of whole dataset
    assert max_num_imgs >= 0

    box_sizes = []
    remaining_num_imgs = max_num_imgs
    total_batches = 0
    for i, batched_inputs in enumerate(data_loader):
        total_batches += len(batched_inputs)
        # cap the batch so the overall image budget is never exceeded
        batch_size = min(remaining_num_imgs, len(batched_inputs))
        batched_inputs = batched_inputs[:batch_size]
        for x in batched_inputs:
            boxes = x["instances"].gt_boxes  # xyxy
            assert isinstance(boxes, Boxes)
            for t in boxes.tensor:
                box_sizes += [[t[2] - t[0], t[3] - t[1]]]
                # NOTE: previous implementation didn't apply +1, thus to match
                # previous (incorrect) results we have to minus the im_scale
                if _legacy_plus_one:  # only for matching old tests
                    im_scale = x["image"].shape[1] / x["height"]  # image is chw
                    box_sizes[-1][0] -= im_scale
                    box_sizes[-1][1] -= im_scale
        # rough total-iteration estimate from the average batch size so far;
        # only used to throttle the progress logging below
        estimated_iters = max_num_imgs / total_batches * (i + 1)
        remaining_num_imgs -= batch_size
        if i % max(1, int(estimated_iters / 20)) == 0:
            # log 20 times at most
            percentage = 100.0 * i / estimated_iters
            logger.info(
                "Processed batch {} ({:.2f}%) from data_loader, got {} boxes,"
                " remaining number of images: {}/{}"
                .format(i, percentage, len(box_sizes), remaining_num_imgs, max_num_imgs)
            )
        if remaining_num_imgs <= 0:
            # batch_size capping above guarantees we land exactly on zero
            assert remaining_num_imgs == 0
            break

    box_sizes = np.array(box_sizes)
    logger.info(
        "Collected {} boxes from {} images"
        .format(len(box_sizes), max_num_imgs)
    )
    return box_sizes
def compute_kmeans_anchors(
    cfg,
    data_loader,
    sort_by_area=True,
    _stride=0,
    _legacy_plus_one=False
):
    """Compute K-Means anchors from gt-box sizes sampled off `data_loader`.

    The configured image budget (MODEL.KMEANS_ANCHORS.NUM_TRAINING_IMG) is
    split across all distributed workers; the per-worker box statistics are
    all-gathered before clustering, so every rank returns the same anchors.

    Args:
        cfg: config; reads MODEL.KMEANS_ANCHORS.{NUM_TRAINING_IMG,
            NUM_CLUSTERS, RNG_SEED}.
        data_loader: train loader yielding batched inputs with "instances".
        sort_by_area: sort the resulting anchors by ascending area.
        _stride: anchors are centered at (_stride/2, _stride/2).
        _legacy_plus_one: forwarded to collect_boxes_size_stats (test-only).

    Returns:
        np.ndarray of shape (NUM_CLUSTERS, 4): anchors in XYXY format.
    """
    assert cfg.MODEL.KMEANS_ANCHORS.NUM_TRAINING_IMG > 0, \
        "Please provide positive MODEL.KMEANS_ANCHORS.NUM_TRAINING_IMG"

    num_training_img = cfg.MODEL.KMEANS_ANCHORS.NUM_TRAINING_IMG
    # split the image budget across ranks; the first `mod_i` ranks read one extra
    div_i, mod_i = divmod(num_training_img, comm.get_world_size())
    num_training_img_i = div_i + (comm.get_rank() < mod_i)

    box_sizes_i = collect_boxes_size_stats(
        data_loader,
        num_training_img_i,
        _legacy_plus_one=_legacy_plus_one,
    )
    all_box_sizes = comm.all_gather(box_sizes_i)
    box_sizes = np.concatenate(all_box_sizes)
    logger.info("Collected {} boxes from all gpus".format(len(box_sizes)))

    assert cfg.MODEL.KMEANS_ANCHORS.NUM_CLUSTERS > 0, \
        "Please provide positive MODEL.KMEANS_ANCHORS.NUM_CLUSTERS"

    from sklearn.cluster import KMeans  # delayed import

    default_anchors = (
        KMeans(
            n_clusters=cfg.MODEL.KMEANS_ANCHORS.NUM_CLUSTERS,
            random_state=cfg.MODEL.KMEANS_ANCHORS.RNG_SEED,
        )
        .fit(box_sizes)
        .cluster_centers_
    )

    # turn each (w, h) cluster center into an XYXY box
    anchors = []
    for anchor in default_anchors:
        w, h = anchor
        # center anchor boxes at (stride/2,stride/2)
        new_anchors = np.hstack(
            (
                _stride / 2 - 0.5 * w,
                _stride / 2 - 0.5 * h,
                _stride / 2 + 0.5 * w,
                _stride / 2 + 0.5 * h,
            )
        )
        anchors.append(new_anchors)
    anchors = np.array(anchors)

    # sort anchors by area
    areas = (anchors[:, 2] - anchors[:, 0]) * (anchors[:, 3] - anchors[:, 1])
    sqrt_areas = np.sqrt(areas)
    if sort_by_area:
        indices = np.argsort(sqrt_areas)
        anchors = anchors[indices]
        sqrt_areas = sqrt_areas[indices].tolist()

    display_str = "\n".join([
        s + "\t sqrt area: {:.2f}".format(a)
        for s, a in zip(str(anchors).split("\n"), sqrt_areas)
    ])
    # FIX: corrected "Compuated" typo in the emitted log message
    logger.info(
        "Computed kmeans anchors (sorted by area: {}):\n{}"
        .format(sort_by_area, display_str)
    )
    return anchors
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# misc.py
# modules that are used in different places but are not a specific type (e.g., backbone)
import torch
import torch.nn as nn
class SplitAndConcat(nn.Module):
    """Chunk the input along `split_dim` and re-join the chunks on `concat_dim`.

    @param split_dim from which axis the data will be chunk
    @param concat_dim to which axis the data will be concatenated
    @param chunk size of the data to be chunk/concatenated
    copied: oculus/face/social_eye/lib/model/resnet_backbone.py
    """

    def __init__(self, split_dim: int = 1, concat_dim: int = 0, chunk: int = 2):
        super().__init__()
        self.split_dim = split_dim
        self.concat_dim = concat_dim
        self.chunk = chunk

    def forward(self, x):
        # split into `chunk` pieces, then stitch them back on the other axis
        pieces = torch.chunk(x, self.chunk, dim=self.split_dim)
        return torch.cat(pieces, dim=self.concat_dim)

    def extra_repr(self):
        return (
            f"split_dim={self.split_dim}, concat_dim={self.concat_dim}, "
            f"chunk={self.chunk}"
        )
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import itertools
import logging
from contextlib import contextmanager
import torch
from detectron2.engine.train_loop import HookBase
logger = logging.getLogger(__name__)
class EMAState(object):
    """Holds a named snapshot of a model's parameters and buffers."""

    def __init__(self):
        self.state = {}

    @classmethod
    def FromModel(cls, model: torch.nn.Module, device: str = ""):
        """Alternate constructor: snapshot `model` into a fresh EMAState."""
        snapshot = cls()
        snapshot.save_from(model, device)
        return snapshot

    def save_from(self, model: torch.nn.Module, device: str = ""):
        """ Save model state from `model` to this object """
        for name, value in self.get_model_state_iterator(model):
            detached = value.detach().clone()
            self.state[name] = detached.to(device) if device else detached

    def apply_to(self, model: torch.nn.Module):
        """ Apply state to `model` from this object """
        with torch.no_grad():
            for name, value in self.get_model_state_iterator(model):
                assert (
                    name in self.state
                ), f"Name {name} not existed, available names {self.state.keys()}"
                value.copy_(self.state[name])

    @contextmanager
    def apply_and_restore(self, model):
        # snapshot current weights, swap in this state, restore on exit
        backup = EMAState.FromModel(model, self.device)
        self.apply_to(model)
        yield backup
        backup.apply_to(model)

    def get_ema_model(self, model):
        """Return a deep copy of `model` with this state applied."""
        ema_model = copy.deepcopy(model)
        self.apply_to(ema_model)
        return ema_model

    @property
    def device(self):
        # device of the stored tensors; None when nothing is saved yet
        if not self.has_inited():
            return None
        return next(iter(self.state.values())).device

    def to(self, device):
        # move tensors in place so external references to `self.state` stay valid
        for name in self.state:
            self.state[name] = self.state[name].to(device)
        return self

    def has_inited(self):
        return self.state

    def clear(self):
        self.state.clear()
        return self

    def get_model_state_iterator(self, model):
        """Iterate (name, tensor) over all parameters AND buffers of `model`."""
        return itertools.chain(model.named_parameters(), model.named_buffers())

    def state_dict(self):
        return self.state

    def load_state_dict(self, state_dict, strict: bool = True):
        self.clear()
        self.state.update(state_dict)
        return torch.nn.modules.module._IncompatibleKeys(
            missing_keys=[], unexpected_keys=[]
        )

    def __repr__(self):
        return f"EMAState(state=[{','.join(self.state.keys())}])"
class EMAUpdater(object):
    """ Model Exponential Moving Average
    Keep a moving average of everything in the model state_dict (parameters and
    buffers). This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    Note: It's very important to set EMA for ALL network parameters (instead of
    parameters that require gradient), including batch-norm moving average mean
    and variance. This leads to significant improvement in accuracy.
    For example, for EfficientNetB3, with default setting (no mixup, lr exponential
    decay) without bn_sync, the EMA accuracy with EMA on params that requires
    gradient is 79.87%, while the corresponding accuracy with EMA on all params
    is 80.61%.
    Also, bn sync should be switched on for EMA.
    """

    def __init__(self, state: EMAState, decay: float = 0.999, device: str = ""):
        self.decay = decay
        self.device = device
        self.state = state

    def init_state(self, model):
        """Reset the tracked state to a fresh snapshot of `model`."""
        self.state.clear()
        self.state.save_from(model, self.device)

    def update(self, model):
        """Blend current model weights into the tracked EMA state."""
        with torch.no_grad():
            for name, current in self.state.get_model_state_iterator(model):
                if self.device:
                    current = current.to(self.device)
                averaged = self.state.state[name]
                averaged.copy_(averaged * self.decay + current * (1.0 - self.decay))
def add_model_ema_configs(_C):
    """Register the MODEL_EMA options on config node ``_C``."""
    from d2go.config import CfgNode as CN

    ema_node = CN()
    ema_node.ENABLED = False
    ema_node.DECAY = 0.999
    # use the same as MODEL.DEVICE when empty
    ema_node.DEVICE = ""
    # When True, loading the ema weight to the model when eval_only=True in build_model()
    ema_node.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = False
    _C.MODEL_EMA = ema_node
def _remove_ddp(model):
from torch.nn.parallel import DistributedDataParallel
if isinstance(model, DistributedDataParallel):
return model.module
return model
def may_build_model_ema(cfg, model):
    """Attach an empty `ema_state` to `model` when MODEL_EMA is enabled."""
    if not cfg.MODEL_EMA.ENABLED:
        return
    target = _remove_ddp(model)
    # the attribute name is part of the EMA contract; refuse to clobber it
    assert not hasattr(
        target, "ema_state"
    ), "Name `ema_state` is reserved for model ema."
    target.ema_state = EMAState()
    logger.info("Using Model EMA.")
def may_get_ema_checkpointer(cfg, model):
    """Return extra checkpointables ({"ema_state": ...}) when EMA is enabled,
    otherwise an empty dict."""
    if not cfg.MODEL_EMA.ENABLED:
        return {}
    return {"ema_state": _remove_ddp(model).ema_state}
def get_model_ema_state(model):
    """ Return the ema state stored in `model`
    """
    unwrapped = _remove_ddp(model)
    assert hasattr(unwrapped, "ema_state")
    return unwrapped.ema_state
def apply_model_ema(model, state=None, save_current=False):
    """Copy the EMA weights stored in/for `model` onto the model itself.

    When `save_current` is True, the pre-existing weights are snapshotted
    first and returned so the caller can restore them; otherwise None.
    """
    model = _remove_ddp(model)
    if state is None:
        state = get_model_ema_state(model)
    previous = EMAState.FromModel(model, state.device) if save_current else None
    state.apply_to(model)
    return previous
@contextmanager
def apply_model_ema_and_restore(model, state=None):
    """Context manager: apply the EMA weights to `model`, yield a snapshot of
    the original weights, and restore them when the context exits."""
    model = _remove_ddp(model)
    if state is None:
        state = get_model_ema_state(model)
    saved = EMAState.FromModel(model, state.device)
    state.apply_to(model)
    yield saved
    saved.apply_to(model)
class EMAHook(HookBase):
    """Trainer hook that maintains the EMA of model weights during training.

    Requires `may_build_model_ema` to have attached `ema_state` to the model.
    """

    def __init__(self, cfg, model):
        model = _remove_ddp(model)
        assert cfg.MODEL_EMA.ENABLED
        assert hasattr(
            model, "ema_state"
        ), "Call `may_build_model_ema` first to initilaize the model ema"
        self.model = model
        self.ema = self.model.ema_state
        # fall back to the model device when MODEL_EMA.DEVICE is empty
        self.device = cfg.MODEL_EMA.DEVICE or cfg.MODEL.DEVICE
        self.ema_updater = EMAUpdater(
            self.model.ema_state, decay=cfg.MODEL_EMA.DECAY, device=self.device
        )

    def before_train(self):
        if self.ema.has_inited():
            # EMA state was loaded from a checkpoint; just move it to device
            self.ema.to(self.device)
        else:
            self.ema_updater.init_state(self.model)

    def after_train(self):
        pass

    def before_step(self):
        pass

    def after_step(self):
        # BUGFIX: the original checked `self.model.train`, which is a bound
        # method and therefore always truthy, so `not self.model.train` was
        # always False and the guard never fired. `Module.training` is the
        # flag that actually reflects train/eval mode.
        if not self.model.training:
            return
        self.ema_updater.update(self.model)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import re
import logging
logger = logging.getLogger(__name__)
def add_model_freezing_configs(_C):
    """Add MODEL.FROZEN_LAYER_REG_EXP: list of regex strings matching the
    names of parameters to freeze (consumed by `set_requires_grad`)."""
    _C.MODEL.FROZEN_LAYER_REG_EXP = []
def set_requires_grad(model, reg_exps, value):
    """Set `requires_grad = value` on every parameter whose name matches any
    regex in `reg_exps` (via `re.match`); unmatched parameters are untouched.

    Args:
        model: module whose named parameters are scanned.
        reg_exps: iterable of regex strings matched against parameter names.
        value: bool assigned to `requires_grad` of each matched parameter.

    Returns:
        (matched_parameter_names, unmatched_parameter_names)
    """
    # NOTE: removed dead accumulators from the original (total_num_parameters,
    # matched_parameters, unmatched_parameters) — they were collected but
    # never used or returned.
    matched_parameter_names = []
    unmatched_parameter_names = []
    for name, parameter in model.named_parameters():
        # first matching regex wins (re.match anchors at the start of `name`)
        if any(re.match(pattern, name) for pattern in reg_exps):
            parameter.requires_grad = value
            matched_parameter_names.append(name)
        else:
            unmatched_parameter_names.append(name)
    logger.info("Matched layers (require_grad={}): {}".format(
        value, matched_parameter_names))
    logger.info("Unmatched layers: {}".format(unmatched_parameter_names))
    return matched_parameter_names, unmatched_parameter_names
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
This is the centralized place to define modeldef for all projects under D2Go.
"""
from . import modeldef # noqa
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
class FBNetV2ModelArch(object):
    """Registry mapping architecture names to their FBNetV2 definitions."""

    _MODEL_ARCH = {}

    @staticmethod
    def add(name, arch):
        """Register `arch` under `name`; duplicate names are rejected."""
        registry = FBNetV2ModelArch._MODEL_ARCH
        assert name not in registry, \
            "Arch name '{}' is already existed".format(name)
        registry[name] = arch

    @staticmethod
    def add_archs(archs):
        """Register every (name, arch) pair from the mapping `archs`."""
        for arch_name, arch_def in archs.items():
            FBNetV2ModelArch.add(arch_name, arch_def)

    @staticmethod
    def get(name):
        """Return a deep copy of the registered arch so callers may mutate it."""
        return copy.deepcopy(FBNetV2ModelArch._MODEL_ARCH[name])
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
from mobile_cv.arch.fbnet_v2.modeldef_utils import _ex, e1, e2, e1p, e3, e4, e6
from d2go.modeling.modeldef.fbnet_modeldef_registry import FBNetV2ModelArch
def _mutated_tuple(tp, pos, value):
tp_list = list(tp)
tp_list[pos] = value
return tuple(tp_list)
def _repeat_last(stage, n=None):
"""
Repeat the last "layer" of given stage, i.e. a (op_type, c, s, n_repeat, ...)
tuple, reset n_repeat if specified otherwise kept the original value.
"""
assert isinstance(stage, list)
assert all(isinstance(x, tuple) for x in stage)
last_layer = copy.deepcopy(stage[-1])
if n is not None:
last_layer = _mutated_tuple(last_layer, 3, n)
return last_layer
# Builder arguments shared by every arch registered in this file.
_BASIC_ARGS = {
    # skip norm and activation for depthwise conv in IRF module, this make the
    # model easier to quantize.
    "dw_skip_bnrelu": True,
    # uncomment below (always_pw and bias) to match model definition of the
    # FBNetV1 builder.
    # "always_pw": True,
    # "bias": False,
    # temporarily disable zero_last_bn_gamma
    "zero_last_bn_gamma": False,
}

DEFAULT_STAGES = [
    # NOTE: each stage is a list of (op_type, out_channels, stride, n_repeat, ...)
    # resolution stage 0, equivalent to 224->112
    [("conv_k3", 32, 2, 1), ("ir_k3", 16, 1, 1, e1)],
    # resolution stage 1, equivalent to 112->56
    [("ir_k3", 24, 2, 2, e6)],
    # resolution stage 2, equivalent to 56->28
    [("ir_k3", 32, 2, 3, e6)],
    # resolution stage 3, equivalent to 28->14
    [("ir_k3", 64, 2, 4, e6), ("ir_k3", 96, 1, 3, e6)],
    # resolution stage 4, equivalent to 14->7
    [("ir_k3", 160, 2, 3, e6), ("ir_k3", 320, 1, 1, e6)],
    # final stage, equivalent to 7->1, ignored
]

# IRF overrides shared by all FBNetV3 tables below
IRF_CFG = {"less_se_channels": False}
# FBNetV3-A variant used by the "dsmask" archs registered below.
# Each stage row: (op_type, out_channels, stride, n_repeat, expansion, IRF_CFG).
FBNetV3_A_dsmask = [
    [
        ("conv_k3", 16, 2, 1),
        ("ir_k3", 16, 1, 1, {"expansion": 1}, IRF_CFG)
    ],
    [
        ("ir_k5", 32, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k5", 32, 1, 1, {"expansion": 2}, IRF_CFG),
    ],
    [
        ("ir_k5", 40, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k3", 40, 1, 3, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k5", 72, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k3", 72, 1, 3, {"expansion": 3}, IRF_CFG),
        ("ir_k5", 112, 1, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k5", 112, 1, 3, {"expansion": 4}, IRF_CFG),
    ],
    [
        ("ir_k5", 184, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k3", 184, 1, 4, {"expansion": 4}, IRF_CFG),
        ("ir_k5", 200, 1, 1, {"expansion": 6}, IRF_CFG),
    ],
]

# Reduced-width version of FBNetV3_A_dsmask (fewer channels and repeats).
FBNetV3_A_dsmask_tiny = [
    [
        ("conv_k3", 8, 2, 1),
        ("ir_k3", 8, 1, 1, {"expansion": 1}, IRF_CFG)
    ],
    [
        ("ir_k5", 16, 2, 1, {"expansion": 3}, IRF_CFG),
        ("ir_k5", 16, 1, 1, {"expansion": 2}, IRF_CFG),
    ],
    [
        ("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k3", 24, 1, 2, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k5", 40, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k3", 40, 1, 2, {"expansion": 3}, IRF_CFG),
        ("ir_k5", 64, 1, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k5", 64, 1, 2, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k5", 92, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k3", 92, 1, 2, {"expansion": 4}, IRF_CFG),
        ("ir_k5", 92, 1, 1, {"expansion": 6}, IRF_CFG),
    ],
]
FBNetV3_A = [
    # FBNetV3 arch without hs
    [
        ("conv_k3", 16, 2, 1),
        ("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)
    ],
    [
        ("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k5", 24, 1, 3, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k5_se", 32, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k3_se", 32, 1, 3, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k5", 64, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k3", 64, 1, 3, {"expansion": 3}, IRF_CFG),
        ("ir_k5_se", 112, 1, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k5_se", 112, 1, 5, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k5_se", 184, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k3_se", 184, 1, 4, {"expansion": 4}, IRF_CFG),
        ("ir_k5_se", 200, 1, 1, {"expansion": 6}, IRF_CFG),
    ],
]

# FBNetV3-B: wider/deeper than A; "_se" ops include squeeze-excitation.
FBNetV3_B = [
    [
        ("conv_k3", 16, 2, 1),
        ("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)
    ],
    [
        ("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k5", 24, 1, 3, {"expansion": 2}, IRF_CFG),
    ],
    [
        ("ir_k5_se", 40, 2, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k5_se", 40, 1, 4, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k5", 72, 2, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k3", 72, 1, 4, {"expansion": 3}, IRF_CFG),
        ("ir_k3_se", 120, 1, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k5_se", 120, 1, 5, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k3_se", 184, 2, 1, {"expansion": 6}, IRF_CFG),
        ("ir_k5_se", 184, 1, 5, {"expansion": 4}, IRF_CFG),
        ("ir_k5_se", 224, 1, 1, {"expansion": 6}, IRF_CFG),
    ],
]
# FBNetV3-C and -D trunks; same stage format as the tables above.
FBNetV3_C = [
    [("conv_k3", 16, 2, 1), ("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)],
    [
        ("ir_k5", 24, 2, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k3", 24, 1, 4, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k5_se", 48, 2, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k5_se", 48, 1, 4, {"expansion": 2}, IRF_CFG),
    ],
    [
        ("ir_k5", 88, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k3", 88, 1, 4, {"expansion": 3}, IRF_CFG),
        ("ir_k3_se", 120, 1, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k5_se", 120, 1, 5, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k5_se", 216, 2, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k5_se", 216, 1, 5, {"expansion": 5}, IRF_CFG),
        ("ir_k5_se", 216, 1, 1, {"expansion": 6}, IRF_CFG),
    ],
]

FBNetV3_D = [
    [("conv_k3", 24, 2, 1), ("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)],
    [
        ("ir_k3", 24, 2, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k3", 24, 1, 5, {"expansion": 2}, IRF_CFG),
    ],
    [
        ("ir_k5_se", 40, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k3_se", 40, 1, 4, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k3", 72, 2, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k3", 72, 1, 4, {"expansion": 3}, IRF_CFG),
        ("ir_k3_se", 128, 1, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k5_se", 128, 1, 6, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k3_se", 208, 2, 1, {"expansion": 6}, IRF_CFG),
        ("ir_k5_se", 208, 1, 5, {"expansion": 5}, IRF_CFG),
        ("ir_k5_se", 240, 1, 1, {"expansion": 6}, IRF_CFG),
    ],
]
# FBNetV3-E and -F trunks; same stage format as the tables above.
FBNetV3_E = [
    [("conv_k3", 24, 2, 1), ("ir_k3", 16, 1, 3, {"expansion": 1}, IRF_CFG)],
    [
        ("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k5", 24, 1, 4, {"expansion": 2}, IRF_CFG),
    ],
    [
        ("ir_k5_se", 48, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k5_se", 48, 1, 4, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k5", 80, 2, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k3", 80, 1, 4, {"expansion": 3}, IRF_CFG),
        ("ir_k3_se", 128, 1, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k5_se", 128, 1, 7, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k3_se", 216, 2, 1, {"expansion": 6}, IRF_CFG),
        ("ir_k5_se", 216, 1, 5, {"expansion": 5}, IRF_CFG),
        ("ir_k5_se", 240, 1, 1, {"expansion": 6}, IRF_CFG),
    ],
]

FBNetV3_F = [
    [("conv_k3", 24, 2, 1), ("ir_k3", 24, 1, 3, {"expansion": 1}, IRF_CFG)],
    [
        ("ir_k5", 32, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k5", 32, 1, 4, {"expansion": 2}, IRF_CFG),
    ],
    [
        ("ir_k5_se", 56, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k5_se", 56, 1, 4, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k5", 88, 2, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k3", 88, 1, 4, {"expansion": 3}, IRF_CFG),
        ("ir_k3_se", 144, 1, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k5_se", 144, 1, 8, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k3_se", 248, 2, 1, {"expansion": 6}, IRF_CFG),
        ("ir_k5_se", 248, 1, 6, {"expansion": 5}, IRF_CFG),
        ("ir_k5_se", 272, 1, 1, {"expansion": 6}, IRF_CFG),
    ],
]
# FBNetV3-G and -H: the largest trunks in this family.
FBNetV3_G = [
    [("conv_k3", 32, 2, 1), ("ir_k3", 24, 1, 3, {"expansion": 1}, IRF_CFG)],
    [
        ("ir_k5", 40, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k5", 40, 1, 4, {"expansion": 2}, IRF_CFG),
    ],
    [
        ("ir_k5_se", 56, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k5_se", 56, 1, 4, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k5", 104, 2, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k3", 104, 1, 4, {"expansion": 3}, IRF_CFG),
        ("ir_k3_se", 160, 1, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k5_se", 160, 1, 8, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k3_se", 264, 2, 1, {"expansion": 6}, IRF_CFG),
        ("ir_k5_se", 264, 1, 6, {"expansion": 5}, IRF_CFG),
        ("ir_k5_se", 288, 1, 2, {"expansion": 6}, IRF_CFG),
    ],
]

FBNetV3_H = [
    [("conv_k3", 48, 2, 1), ("ir_k3", 32, 1, 4, {"expansion": 1}, IRF_CFG)],
    [
        ("ir_k5", 64, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k5", 64, 1, 6, {"expansion": 2}, IRF_CFG),
    ],
    [
        ("ir_k5_se", 80, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k5_se", 80, 1, 6, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k5", 160, 2, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k3", 160, 1, 6, {"expansion": 3}, IRF_CFG),
        ("ir_k3_se", 240, 1, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k5_se", 240, 1, 12, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k3_se", 400, 2, 1, {"expansion": 6}, IRF_CFG),
        ("ir_k5_se", 400, 1, 8, {"expansion": 5}, IRF_CFG),
        ("ir_k5_se", 480, 1, 3, {"expansion": 6}, IRF_CFG),
    ],
]
FBNetV3_A_no_se = [
    # FBNetV3 without hs and SE (SE is not quantization friendly)
    [
        ("conv_k3", 16, 2, 1),
        ("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)
    ],
    [
        ("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k5", 24, 1, 3, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k5", 32, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k3", 32, 1, 3, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k5", 64, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k3", 64, 1, 3, {"expansion": 3}, IRF_CFG),
        ("ir_k5", 112, 1, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k5", 112, 1, 5, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k5", 184, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k3", 184, 1, 4, {"expansion": 4}, IRF_CFG),
        ("ir_k5", 200, 1, 1, {"expansion": 6}, IRF_CFG),
    ],
]

# FBNetV3_B with all "_se" ops replaced by plain IRF blocks (see note above).
FBNetV3_B_no_se = [
    [
        ("conv_k3", 16, 2, 1),
        ("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)
    ],
    [
        ("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k5", 24, 1, 3, {"expansion": 2}, IRF_CFG),
    ],
    [
        ("ir_k5", 40, 2, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k5", 40, 1, 4, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k5", 72, 2, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k3", 72, 1, 4, {"expansion": 3}, IRF_CFG),
        ("ir_k3", 120, 1, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k5", 120, 1, 5, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k3", 184, 2, 1, {"expansion": 6}, IRF_CFG),
        ("ir_k5", 184, 1, 5, {"expansion": 4}, IRF_CFG),
        ("ir_k5", 224, 1, 1, {"expansion": 6}, IRF_CFG),
    ],
]
# FBNetV3_B model, a lighter version for real-time inference
FBNetV3_B_light_no_se = [
    [
        ("conv_k3", 16, 2, 1),
        ("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)
    ],
    [
        ("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG),
        ("ir_k5", 24, 1, 2, {"expansion": 2}, IRF_CFG),
    ],
    [
        ("ir_k5", 40, 2, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k5", 40, 1, 3, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k5", 72, 2, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k3", 72, 1, 4, {"expansion": 3}, IRF_CFG),
        ("ir_k3", 120, 1, 1, {"expansion": 5}, IRF_CFG),
        ("ir_k5", 120, 1, 5, {"expansion": 3}, IRF_CFG),
    ],
    [
        ("ir_k3", 184, 2, 1, {"expansion": 6}, IRF_CFG),
        ("ir_k5", 184, 1, 5, {"expansion": 4}, IRF_CFG),
        ("ir_k5", 224, 1, 1, {"expansion": 6}, IRF_CFG),
    ],
]

# Head definitions used by the archs below; same (op, c, stride, n_repeat, e*)
# format, where a negative stride (-2) up-samples instead of down-sampling.
LARGE_BOX_HEAD_STAGES = [
    [("ir_k3", 160, 2, 1, e4), ("ir_k3", 160, 1, 2, e6), ("ir_k3", 240, 1, 1, e6)],
]
SMALL_BOX_HEAD_STAGES = [
    [("ir_k3", 128, 2, 1, e4), ("ir_k3", 128, 1, 2, e6), ("ir_k3", 160, 1, 1, e6)],
]
TINY_BOX_HEAD_STAGES = [
    [("ir_k3", 64, 2, 1, e4), ("ir_k3", 64, 1, 2, e4), ("ir_k3", 80, 1, 1, e4)],
]
LARGE_UPSAMPLE_HEAD_STAGES = [
    [("ir_k3", 160, 1, 1, e4), ("ir_k3", 160, 1, 3, e6), ("ir_k3", 80, -2, 1, e3)],
]
LARGE_UPSAMPLE_HEAD_D21_STAGES = [
    [("ir_k3", 192, 1, 1, e4), ("ir_k3", 192, 1, 5, e3), ("ir_k3", 96, -2, 1, e3)],
]
SMALL_UPSAMPLE_HEAD_STAGES = [
    [("ir_k3", 128, 1, 1, e4), ("ir_k3", 128, 1, 3, e6), ("ir_k3", 64, -2, 1, e3)],
]
# NOTE: Compared with SMALL_UPSAMPLE_HEAD_STAGES, this does one more down-sample
# in the first "layer" and then up-sample twice
SMALL_DS_UPSAMPLE_HEAD_STAGES = [
    [("ir_k3", 128, 2, 1, e4), ("ir_k3", 128, 1, 2, e6), ("ir_k3", 128, -2, 1, e6), ("ir_k3", 64, -2, 1, e3)],  # noqa
]
TINY_DS_UPSAMPLE_HEAD_STAGES = [
    [("ir_k3", 64, 2, 1, e4), ("ir_k3", 64, 1, 2, e4), ("ir_k3", 64, -2, 1, e4), ("ir_k3", 40, -2, 1, e3)],  # noqa
]
FPN_UPSAMPLE_HEAD_STAGES = [
    [("ir_k3", 96, 1, 1, e6), ("ir_k3", 160, 1, 3, e6), ("ir_k3", 80, -2, 1, e3)],
]
# Built-in arch registry. Convention: the trunk takes the first 4 stages of a
# table, the 5th (index 4) stage feeds the bbox head, and the rpn repeats the
# last trunk layer; "mask"/"kpts" heads are only present for some archs.
MODEL_ARCH_BUILTIN = {
    "default": {
        "trunk": DEFAULT_STAGES[0:4],
        "rpn": [[_repeat_last(DEFAULT_STAGES[3])]],
        "bbox": LARGE_BOX_HEAD_STAGES,
        "mask": LARGE_UPSAMPLE_HEAD_STAGES,
        "kpts": LARGE_UPSAMPLE_HEAD_STAGES,
        "basic_args": _BASIC_ARGS,
    },
    "default_dsmask": {
        "trunk": DEFAULT_STAGES[0:4],
        "rpn": [[_repeat_last(DEFAULT_STAGES[3])]],
        "bbox": SMALL_BOX_HEAD_STAGES,
        "mask": SMALL_DS_UPSAMPLE_HEAD_STAGES,
        "kpts": SMALL_DS_UPSAMPLE_HEAD_STAGES,
        "basic_args": _BASIC_ARGS,
    },
    "FBNetV3_A": {
        "trunk": FBNetV3_A[0:4],
        "rpn": [[_repeat_last(FBNetV3_A[3])]],
        "bbox": [FBNetV3_A[4]],
        "basic_args": _BASIC_ARGS,
    },
    "FBNetV3_B": {
        "trunk": FBNetV3_B[0:4],
        "rpn": [[_repeat_last(FBNetV3_B[3])]],
        "bbox": [FBNetV3_B[4]],
        "basic_args": _BASIC_ARGS,
    },
    "FBNetV3_C": {
        "trunk": FBNetV3_C[0:4],
        "rpn": [[_repeat_last(FBNetV3_C[3])]],
        "bbox": [FBNetV3_C[4]],
        "basic_args": _BASIC_ARGS,
    },
    "FBNetV3_D": {
        "trunk": FBNetV3_D[0:4],
        "rpn": [[_repeat_last(FBNetV3_D[3])]],
        "bbox": [FBNetV3_D[4]],
        "basic_args": _BASIC_ARGS,
    },
    "FBNetV3_E": {
        "trunk": FBNetV3_E[0:4],
        "rpn": [[_repeat_last(FBNetV3_E[3])]],
        "bbox": [FBNetV3_E[4]],
        "basic_args": _BASIC_ARGS,
    },
    "FBNetV3_F": {
        "trunk": FBNetV3_F[0:4],
        "rpn": [[_repeat_last(FBNetV3_F[3])]],
        "bbox": [FBNetV3_F[4]],
        "basic_args": _BASIC_ARGS,
    },
    "FBNetV3_G": {
        "trunk": FBNetV3_G[0:4],
        "rpn": [[_repeat_last(FBNetV3_G[3])]],
        "bbox": [FBNetV3_G[4]],
        "basic_args": _BASIC_ARGS,
    },
    "FBNetV3_H": {
        "trunk": FBNetV3_H[0:4],
        "rpn": [[_repeat_last(FBNetV3_H[3])]],
        "bbox": [FBNetV3_H[4]],
        "basic_args": _BASIC_ARGS,
    },
    # NOTE: "_C5" keeps all 5 dsmask stages in the trunk (no separate bbox stage)
    "FBNetV3_A_dsmask_C5": {
        "trunk": FBNetV3_A_dsmask,
        "rpn": [[_repeat_last(FBNetV3_A_dsmask[3])]],
        "bbox": SMALL_BOX_HEAD_STAGES,
        "mask": SMALL_DS_UPSAMPLE_HEAD_STAGES,
        "kpts": SMALL_DS_UPSAMPLE_HEAD_STAGES,
        "basic_args": _BASIC_ARGS,
    },
    "FBNetV3_A_dsmask": {
        "trunk": FBNetV3_A_dsmask[0:4],
        "rpn": [[_repeat_last(FBNetV3_A_dsmask[3])]],
        "bbox": SMALL_BOX_HEAD_STAGES,
        "mask": SMALL_DS_UPSAMPLE_HEAD_STAGES,
        "kpts": SMALL_DS_UPSAMPLE_HEAD_STAGES,
        "basic_args": _BASIC_ARGS,
    },
    "FBNetV3_A_dsmask_tiny": {
        "trunk": FBNetV3_A_dsmask_tiny[0:4],
        "rpn": [[_repeat_last(FBNetV3_A_dsmask_tiny[3])]],
        "bbox": TINY_BOX_HEAD_STAGES,
        "mask": TINY_DS_UPSAMPLE_HEAD_STAGES,
        "kpts": TINY_DS_UPSAMPLE_HEAD_STAGES,
        "basic_args": _BASIC_ARGS,
    },
    "FBNetV3_B_light_large": {
        "trunk": FBNetV3_B_light_no_se[0:4],
        "rpn": [[_repeat_last(FBNetV3_B_light_no_se[3])]],
        "bbox": SMALL_BOX_HEAD_STAGES,
        "mask": SMALL_DS_UPSAMPLE_HEAD_STAGES,
        "kpts": LARGE_UPSAMPLE_HEAD_D21_STAGES,
        "basic_args": _BASIC_ARGS,
    },
    "FBNetV3_G_fpn": {
        "trunk": FBNetV3_G[0:5],  # FPN uses all 5 stages
        "rpn": [[_repeat_last(FBNetV3_G[3], n=1)]],
        "bbox": [FBNetV3_G[4]],
        "mask": FPN_UPSAMPLE_HEAD_STAGES,
        "kpts": LARGE_UPSAMPLE_HEAD_D21_STAGES,
        "basic_args": _BASIC_ARGS,
    },
}
FBNetV2ModelArch.add_archs(MODEL_ARCH_BUILTIN)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import contextlib
import copy
import inspect
import logging
import torch
import torch.quantization.quantize_fx
from detectron2.checkpoint import DetectionCheckpointer
from mobile_cv.arch.utils import fuse_utils
from mobile_cv.common.misc.iter_utils import recursive_iterate
logger = logging.getLogger(__name__)
def _is_observer_key(state_dict_key):
observer_keys = ["activation_post_process", "weight_fake_quant"]
return any(x in state_dict_key for x in observer_keys)
class QATCheckpointer(DetectionCheckpointer):
    """
    Checkpointer that supports loading a (QAT / non-QAT) checkpoint into a
    (QAT / non-QAT) model, remapping state_dict keys when needed.
    """

    @classmethod
    def _is_q_state_dict(cls, state_dict):
        # A state_dict counts as quantized if any key belongs to an observer.
        return any(_is_observer_key(key) for key in state_dict)

    def _load_model(self, checkpoint):
        model_is_qat = self._is_q_state_dict(self.model.state_dict())
        checkpoint_is_qat = self._is_q_state_dict(checkpoint["model"])

        if not model_is_qat and checkpoint_is_qat:
            # Loading a QAT checkpoint into a non-QAT model is not supported.
            raise NotImplementedError()

        if not (model_is_qat and not checkpoint_is_qat):
            # Same "QAT-ness" on both sides: default loading works.
            # TODO: maybe suppress shape mismatch
            # For models trained with QAT and per-channel quant, the inital size of the
            # buffers in fake_quant and observer modules does not reflect the size in
            # state_dict, which is updated only when convert is called.
            return super()._load_model(checkpoint)

        # Remaining case: QAT model, non-QAT checkpoint.
        logger.info("Loading QAT model with non-QAT checkpoint, ignore observers!")
        mapping = getattr(self.model, "_non_qat_to_qat_state_dict_map", {})
        # map the key from non-QAT model to QAT model if possible
        checkpoint["model"] = {
            mapping.get(key, key): value for key, value in checkpoint["model"].items()
        }
        incompatible = super()._load_model(checkpoint)
        # suppress the missing observer keys warning
        # NOTE: for some reason incompatible.missing_keys can have duplicated keys,
        # here we replace the entire list rather than calling .remove()
        incompatible.missing_keys[:] = [
            key for key in incompatible.missing_keys if not _is_observer_key(key)
        ]
        return incompatible
def add_quantization_default_configs(_C):
    """Populate `_C.QUANTIZATION` (and sub-nodes) with default quantization options."""
    node_cls = type(_C)

    _C.QUANTIZATION = node_cls()
    # Note: EAGER_MODE == False currently represents FX graph mode quantization
    _C.QUANTIZATION.EAGER_MODE = True
    _C.QUANTIZATION.BACKEND = "fbgemm"
    # used to enable metarch set_custom_qscheme (need to implement)
    # this is a limited implementation where only str is provided to change options
    _C.QUANTIZATION.CUSTOM_QSCHEME = ""

    # ---- quantization-aware training -------------------------------------
    _C.QUANTIZATION.QAT = node_cls()
    _C.QUANTIZATION.QAT.ENABLED = False
    # QAT will use more GPU memory, user can change this factor to reduce the batch size
    # after fake quant is enabled. Setting it to 0.5 should guarantee no memory increase
    # compared with QAT is disabled.
    _C.QUANTIZATION.QAT.BATCH_SIZE_FACTOR = 1.0
    # the iteration number to start QAT, (i.e. enable fake quant). The default value of
    # SOLVER.MAX_ITER is 40k and SOLVER.STEPS is (30k,), here we turn on QAT at 35k, so
    # the last 5k iterations will run with QAT with decreased learning rate.
    _C.QUANTIZATION.QAT.START_ITER = 35000
    # the iteration number to enable observer, it's usually set to be the same as
    # QUANTIZATION.QAT.START_ITER.
    _C.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER = 35000
    # the iteration number to disable observer, here it's 3k after enabling the fake
    # quant, 3k roughly corresponds to 7 out of 90 epochs in classification.
    _C.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER = 35000 + 3000
    # the iteration number to freeze BN, here it's 2k after enabling the fake quant, 2k
    # roughly corresponds to 5 out of 90 epochs for classification.
    _C.QUANTIZATION.QAT.FREEZE_BN_ITER = 35000 + 2000
    # qat hook will run observers update_stat if it exists
    # after update_observer_stats_period iters
    _C.QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIODICALLY = False
    _C.QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIOD = 1

    # ---- post-training quantization --------------------------------------
    _C.QUANTIZATION.PTQ = node_cls()
    _C.QUANTIZATION.PTQ.CALIBRATION_NUM_IMAGES = 1
    _C.QUANTIZATION.PTQ.CALIBRATION_FORCE_ON_GPU = False

    # deprecated
    _C.QUANTIZATION.SILICON_QAT = node_cls()
    _C.QUANTIZATION.SILICON_QAT.ENABLED = False

    # register deprecated and renamed keys
    _C.register_deprecated_key("QUANTIZATION.QAT.LOAD_PRETRAINED")
    _C.register_renamed_key("QUANTIZATION.QAT.BACKEND", "QUANTIZATION.BACKEND")
    _C.register_deprecated_key("QUANTIZATION.ENABLE_CUSTOM_QSCHEME")
@contextlib.contextmanager
def silicon_qat_build_model_context(cfg):
    """
    Context manager for building a model under silicon QAT (deprecated).
    When QUANTIZATION.SILICON_QAT.ENABLED, quant ops are mocked for the
    duration of the context; otherwise it is a no-op.
    """
    if not cfg.QUANTIZATION.SILICON_QAT.ENABLED:
        yield
        return

    # Imported lazily: mobile_cv silicon QAT is only needed when enabled.
    from mobile_cv.silicon_pytorch_qat.replace_op import mock_quant_ops

    managers = [
        mock_quant_ops(quant_op="quant_add"),
        mock_quant_ops(quant_op="quant_fbb_convbnrelu"),
    ]
    with contextlib.ExitStack() as stack:
        for manager in managers:
            stack.enter_context(manager)
        yield
# TODO: model.to(device) might not work for detection meta-arch, this function is the
# workaround, in general, we might need a meta-arch API for this if needed.
def _cast_detection_model(model, device):
    """Move a detection meta-arch model (or its Caffe2 wrapper) to `device`.

    Handles the legacy `normalizer` closure that captures pixel_mean/pixel_std
    tensors, which a plain `model.to(device)` would leave on the old device.
    Returns the (mutated) model.
    """
    # check model is an instance of one of the meta arch
    from detectron2.export.caffe2_modeling import Caffe2MetaArch
    from detectron2.modeling import META_ARCH_REGISTRY
    if isinstance(model, Caffe2MetaArch):
        # Cast the wrapped D2 model recursively; the wrapper itself holds no tensors.
        model._wrapped_model = _cast_detection_model(model._wrapped_model, device)
        return model
    assert isinstance(model, tuple(META_ARCH_REGISTRY._obj_map.values()))
    model.to(device)
    # cast normalizer separately: only needed when the model keeps a closure-based
    # normalizer instead of registered pixel_mean/pixel_std buffers.
    if hasattr(model, "normalizer") and not (
        hasattr(model, "pixel_mean") and hasattr(model, "pixel_std")
    ):
        # Extract the tensors captured by the old closure, move them, and rebuild it.
        pixel_mean = inspect.getclosurevars(model.normalizer).nonlocals["pixel_mean"]
        pixel_std = inspect.getclosurevars(model.normalizer).nonlocals["pixel_std"]
        pixel_mean = pixel_mean.to(device)
        pixel_std = pixel_std.to(device)
        model.normalizer = lambda x: (x - pixel_mean) / pixel_std
    return model
def add_d2_quant_mapping(mappings):
    """HACK: register d2-specific module mappings for eager-mode quantization.

    Each (float_module -> quantized_module) pair in `mappings` is added to the
    default static-quant and QAT mapping tables unless already present.
    """
    import torch.quantization.quantization_mappings as qm

    for float_type, quantized_type in mappings.items():
        if float_type not in qm.get_default_static_quant_module_mappings():
            qm.DEFAULT_STATIC_QUANT_MODULE_MAPPINGS[float_type] = quantized_type
        if float_type not in qm.get_default_qat_module_mappings():
            qm.DEFAULT_QAT_MODULE_MAPPINGS[float_type] = quantized_type
# The `mock_quantization_type` decorator may not be needed anymore to unify
# detectron2.layers modules and torch.nn modules since Pytorch 1.5. See comments on D23790034.
def mock_quantization_type(quant_func):
    """Decorator for eager-mode quantization entry points.

    PyTorch's `from_float()` checks `type(mod)` against torch.nn classes, but
    detectron2 historically used its own `detectron2.layers.Linear`. While the
    wrapped function runs, `type()` lookups inside torch's linear modules are
    patched so d2l.Linear is reported as torch.nn.Linear.
    """
    import mock
    import builtins
    import functools
    import detectron2.layers as d2l

    type_mapping = {d2l.Linear: torch.nn.Linear}

    from d2go.utils.misc import check_version
    if check_version(torch, '1.7.2', warning_only=True):
        add_d2_quant_mapping(type_mapping)

    real_type = builtins.type

    def _new_type(obj):
        # Report mapped types (d2l.Linear -> torch.nn.Linear); otherwise the real type.
        rtype = real_type(obj)
        return type_mapping.get(rtype, rtype)

    @functools.wraps(quant_func)
    def wrapper(cfg, model, *args, **kwargs):
        # BUGFIX: the original checked `type(d2l.Linear) == torch.nn.Linear`, which
        # compares the metaclass (always `type`) and thus never matched. Compare the
        # classes themselves to detect when mocking is no longer needed.
        if d2l.Linear is torch.nn.Linear:
            # we do not need the mock when the type is already the expected one,
            # consider removing this related code
            logger.warning(
                "`detectron2.layers.Linear` is in expected type (torch.nn.Linear),"
                "consider removing this code `mock_quantization_type`."
            )
            return quant_func(cfg, model, *args, **kwargs)

        if not cfg.QUANTIZATION.EAGER_MODE:
            # FX graph mode does not go through `from_float()` type checks.
            return quant_func(cfg, model, *args, **kwargs)

        # `from_float()` in `torch.nn.quantized.modules.linear.Linear` and
        # `torch.nn.qat.modules.linear` checks if the type of `mod` is torch.Linear,
        # hack it to return the expected value
        with mock.patch("torch.nn.quantized.modules.linear.type") as mock_type:
            with mock.patch("torch.nn.qat.modules.linear.type") as mock_type2:
                mock_type.side_effect = _new_type
                mock_type2.side_effect = _new_type
                return quant_func(cfg, model, *args, **kwargs)

    return wrapper
def default_prepare_for_quant(cfg, model):
    """
    Default implementation of preparing a model for quantization. This function will
    be called before training if QAT is enabled, or before calibration during PTQ if
    the model is not already quantized.
    NOTE:
        - This is the simplest implementation, most meta-arch needs its own version.
        - For eager model, user should make sure the returned model has Quant/DeQuant
            insert. This can be done by wrapping the model or defining the model with
            quant stubs.
        - QAT/PTQ can be determined by model.training.
        - Currently the input model can be changed inplace since we won't re-use the
            input model.
        - Currently this API doesn't include the final torch.quantization.prepare(_qat)
            call since existing usecases don't have further steps after it.
    Args:
        model (nn.Module): a non-quantized model.
        cfg (CfgNode): config
    Return:
        nn.Module: a ready model for QAT training or PTQ calibration
    """
    backend = cfg.QUANTIZATION.BACKEND
    # QAT uses a fake-quant qconfig, PTQ an observer-only one.
    if model.training:
        qconfig = torch.quantization.get_default_qat_qconfig(backend)
    else:
        qconfig = torch.quantization.get_default_qconfig(backend)

    if cfg.QUANTIZATION.EAGER_MODE:
        model = fuse_utils.fuse_model(model, inplace=True)
        torch.backends.quantized.engine = backend
        model.qconfig = qconfig
        # TODO(future diff): move the torch.quantization.prepare(...) call
        # here, to be consistent with the FX branch
    else:
        # FX graph mode quantization: prepare handles fusion internally.
        prepare_fn = (
            torch.quantization.quantize_fx.prepare_qat_fx
            if model.training
            else torch.quantization.quantize_fx.prepare_fx
        )
        model = prepare_fn(model, {"": qconfig})

    logger.info("Setup the model with qconfig:\n{}".format(qconfig))
    return model
@mock_quantization_type
def post_training_quantize(cfg, model, data_loader):
    """ Calibrate a model, convert it to a quantized pytorch model """
    # Work on a copy so the caller's (float) model is untouched.
    model = copy.deepcopy(model)
    model.eval()
    # TODO: check why some parameters will have gradient
    for param in model.parameters():
        param.requires_grad = False
    # Prefer the model's own quantization hook; fall back to the default flow.
    if hasattr(model, "prepare_for_quant"):
        model = model.prepare_for_quant(cfg)
    else:
        logger.info("Using default implementation for prepare_for_quant")
        model = default_prepare_for_quant(cfg, model)
    # In eager mode, prepare() still needs to be called explicitly (the FX path
    # already inserted observers inside prepare_for_quant).
    if cfg.QUANTIZATION.EAGER_MODE:
        torch.quantization.prepare(model, inplace=True)
    logger.info("Prepared the PTQ model for calibration:\n{}".format(model))
    # Option for forcing running calibration on GPU, works only when the model supports
    # casting both model and inputs.
    calibration_force_on_gpu = (
        cfg.QUANTIZATION.PTQ.CALIBRATION_FORCE_ON_GPU and torch.cuda.is_available()
    )
    if calibration_force_on_gpu:
        # NOTE: model.to(device) may not handle cases such as normalizer, FPN, only
        # do move to GPU if specified.
        _cast_detection_model(model, "cuda")
    calibration_iters = cfg.QUANTIZATION.PTQ.CALIBRATION_NUM_IMAGES
    for idx, inputs in enumerate(data_loader):
        logger.info("Running calibration iter: {}/{}".format(idx, calibration_iters))
        if calibration_force_on_gpu:
            # Move every tensor in the (possibly nested) inputs to GPU in place.
            iters = recursive_iterate(inputs)
            for x in iters:
                if isinstance(x, torch.Tensor):
                    iters.send(x.to("cuda"))
            inputs = iters.value
        with torch.no_grad():
            model(inputs)
        if idx + 1 == calibration_iters:
            break
    else:
        # for/else: reached only when the loader ran out before calibration_iters.
        logger.warning("Can't run enough calibration iterations")
    # cast model back to the original device
    if calibration_force_on_gpu:
        _cast_detection_model(model, cfg.MODEL.DEVICE)
    return model
@mock_quantization_type
def setup_qat_model(cfg, model, enable_fake_quant=False, enable_observer=False):
    """Prepare `model` for quantization-aware training.

    Fuses/prepares the model (eager or FX mode), optionally disables fake-quant
    and observers (they are typically enabled later by training hooks), and
    records `_non_qat_to_qat_state_dict_map` so QATCheckpointer can load
    non-QAT checkpoints into the prepared model. Returns the prepared model.
    """
    if hasattr(model, "_non_qat_to_qat_state_dict_map"):
        raise RuntimeError("The model is already setup to be QAT, cannot setup again!")
    device = model.device
    torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
    # Snapshot of the pre-QAT state_dict, used below to build the key map.
    original_state_dict_shapes = {k: v.shape for k, v in model.state_dict().items()}
    if cfg.QUANTIZATION.EAGER_MODE:
        if hasattr(model, "prepare_for_quant"):
            model = model.prepare_for_quant(cfg)
        else:
            logger.info("Using default implementation for prepare_for_quant")
            model = default_prepare_for_quant(cfg, model)
        # TODO(future diff): move this into prepare_for_quant to match FX branch
        torch.quantization.prepare_qat(model, inplace=True)
    else:  # FX graph mode quantization
        if hasattr(model, "prepare_for_quant"):
            model = model.prepare_for_quant(cfg)
        else:
            logger.info("Using default implementation for prepare_for_quant")
            model = default_prepare_for_quant(cfg, model)
    # Move newly added observers to the original device
    model.to(device)
    if not enable_fake_quant:
        logger.info("Disabling fake quant ...")
        model.apply(torch.quantization.disable_fake_quant)
    if not enable_observer:
        logger.info("Disabling observer ...")
        model.apply(torch.quantization.disable_observer)
    # fuse_model and prepare_qat may change the state_dict of model, keep a map from the
    # orginal model to the key QAT in order to load weight from non-QAT model.
    new_state_dict_shapes = {k: v.shape for k, v in model.state_dict().items()}
    new_state_dict_non_observer_keys = [
        k for k in new_state_dict_shapes if not _is_observer_key(k)
    ]
    assert len(new_state_dict_non_observer_keys) == len(original_state_dict_shapes)
    if cfg.QUANTIZATION.EAGER_MODE:
        # Eager mode preserves key order, so old/new keys can be zipped positionally.
        for n_k, o_k in zip(new_state_dict_non_observer_keys, original_state_dict_shapes):
            assert new_state_dict_shapes[n_k] == original_state_dict_shapes[o_k]
        # _q_state_dict_map will store
        model._non_qat_to_qat_state_dict_map = dict(
            zip(original_state_dict_shapes, new_state_dict_non_observer_keys)
        )
    else:
        # in FX, the order of where modules appear in the state_dict may change,
        # so we need to match by key
        def get_new_bn_key(old_bn_key):
            # tries to adjust the key for conv-bn fusion, where
            # root
            #    - conv
            #    - bn
            #
            # becomes
            #
            # root
            #    - conv
            #        - bn
            return old_bn_key.replace(".bn.", ".conv.bn.")
        model._non_qat_to_qat_state_dict_map = {}
        for key in original_state_dict_shapes.keys():
            if key in new_state_dict_non_observer_keys:
                model._non_qat_to_qat_state_dict_map[key] = key
            else:
                # Key vanished: most likely a BN folded under its conv by fusion.
                maybe_new_bn_key = get_new_bn_key(key)
                if maybe_new_bn_key in new_state_dict_non_observer_keys:
                    model._non_qat_to_qat_state_dict_map[key] = maybe_new_bn_key
    return model
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.layers import cat
from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads
from d2go.config import CfgNode as CN
from d2go.data.dataset_mappers import (
D2GO_DATA_MAPPER_REGISTRY,
D2GoDatasetMapper,
)
from d2go.utils.helper import alias
def add_subclass_configs(cfg):
    """Register the MODEL.SUBCLASS options on the given config."""
    cfg.MODEL.SUBCLASS = CN()
    cfg.MODEL.SUBCLASS.SUBCLASS_ON = False
    cfg.MODEL.SUBCLASS.NUM_SUBCLASSES = 0  # must be set
def fetch_subclass_from_extras(dataset_dict):
    """
    Retrieve subclass (eg. hand gesture per RPN region) info from dataset dict.
    """
    return [
        anno.get("extras")["subclass_id"] for anno in dataset_dict["annotations"]
    ]
@D2GO_DATA_MAPPER_REGISTRY.register()
class SubclassDatasetMapper(D2GoDatasetMapper):
    """
    Dataset mapper that runs D2GoDatasetMapper and additionally attaches
    `gt_subclasses` tensors to the mapped instances.
    """

    def __init__(self, cfg, is_train, tfm_gens=None, subclass_fetcher=None):
        super().__init__(cfg, is_train=is_train, tfm_gens=tfm_gens)
        # Default to reading subclass ids from the annotations' "extras" field.
        self.subclass_fetcher = subclass_fetcher or fetch_subclass_from_extras
        # NOTE: field doesn't exist when loading a (old) caffe2 model.
        # self.subclass_on = cfg.MODEL.SUBCLASS.SUBCLASS_ON
        self.subclass_on = True

    def _original_call(self, dataset_dict):
        """
        Map the dataset dict with D2GoDatasetMapper, then augment with subclass gt tensors.
        """
        # Transform removes key 'annotations' from the dataset dict
        mapped = super()._original_call(dataset_dict)
        if self.is_train and self.subclass_on:
            subclass_ids = self.subclass_fetcher(dataset_dict)
            mapped["instances"].gt_subclasses = torch.tensor(
                subclass_ids, dtype=torch.int64
            )
        return mapped
@ROI_HEADS_REGISTRY.register()
class StandardROIHeadsWithSubClass(StandardROIHeads):
    """
    A Standard ROIHeads which contains an addition of subclass head.
    """
    def __init__(self, cfg, input_shape):
        super().__init__(cfg, input_shape)
        self.subclass_on = cfg.MODEL.SUBCLASS.SUBCLASS_ON
        if not self.subclass_on:
            return
        # One extra logit for the background class (index 0).
        self.subclass_head = nn.Linear(
            self.box_head.output_shape.channels, cfg.MODEL.SUBCLASS.NUM_SUBCLASSES + 1
        )
        nn.init.normal_(self.subclass_head.weight, std=0.01)
        nn.init.constant_(self.subclass_head.bias, 0.0)
    def forward(self, images, features, proposals, targets=None):
        """
        Same as StandardROIHeads.forward but add logic for subclass.
        """
        if not self.subclass_on:
            return super().forward(images, features, proposals, targets)
        # --- start copy -------------------------------------------------------
        # (copied from StandardROIHeads.forward so box_features can be reused below)
        del images
        if self.training:
            proposals = self.label_and_sample_proposals(proposals, targets)
            # NOTE: `has_gt` = False for negatives and we must manually register `gt_subclasses`,
            # because custom gt_* fields will not be automatically registered in sampled proposals.
            for pp_per_im in proposals:
                if not pp_per_im.has("gt_subclasses"):
                    background_subcls_idx = 0
                    # NOTE(review): hard-codes CUDA tensors; presumably training is
                    # GPU-only here — confirm before running this head on CPU.
                    pp_per_im.gt_subclasses = torch.cuda.LongTensor(len(pp_per_im)).fill_(background_subcls_idx)
        del targets
        features_list = [features[f] for f in self.in_features]
        box_features = self.box_pooler(features_list, [x.proposal_boxes for x in proposals])
        box_features = self.box_head(box_features)
        predictions = self.box_predictor(box_features)
        # --- end copy ---------------------------------------------------------
        # NOTE: don't delete box_features, keep it temporarily
        # del box_features
        # Flatten per-RoI features so they can feed the linear subclass head.
        box_features = box_features.view(
            box_features.shape[0],
            np.prod(box_features.shape[1:])
        )
        pred_subclass_logits = self.subclass_head(box_features)
        if self.training:
            losses = self.box_predictor.losses(predictions, proposals)
            # During training the proposals used by the box head are
            # used by the mask, keypoint (and densepose) heads.
            losses.update(self._forward_mask(features, proposals))
            losses.update(self._forward_keypoint(features, proposals))
            # subclass head: standard cross-entropy against the per-proposal gt ids.
            gt_subclasses = cat([p.gt_subclasses for p in proposals], dim=0)
            loss_subclass = F.cross_entropy(
                pred_subclass_logits, gt_subclasses, reduction="mean"
            )
            losses.update({"loss_subclass": loss_subclass})
            return proposals, losses
        else:
            pred_instances, kept_indices = self.box_predictor.inference(
                predictions, proposals
            )
            # During inference cascaded prediction is used: the mask and keypoints
            # heads are only applied to the top scoring box detections.
            pred_instances = self.forward_with_given_boxes(features, pred_instances)
            # subclass head: select the probability rows for the kept detections.
            probs = F.softmax(pred_subclass_logits, dim=-1)
            for pred_instances_i, kept_indices_i in zip(pred_instances, kept_indices):
                pred_instances_i.pred_subclass_prob = torch.index_select(
                    probs,
                    dim=0,
                    index=kept_indices_i.to(torch.int64),
                )
            if torch.onnx.is_in_onnx_export():
                assert len(pred_instances) == 1
                # Name the output blob so it can be found in the exported graph.
                pred_instances[0].pred_subclass_prob = alias(
                    pred_instances[0].pred_subclass_prob,
                    "subclass_prob_nms"
                )
            return pred_instances, {}
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .build import build_optimizer_mapper
__all__ = ['build_optimizer_mapper']
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
import itertools
from typing import Any, Dict, List, Optional, Set
from detectron2.utils.registry import Registry
from detectron2.solver.build import maybe_add_gradient_clipping as d2_maybe_add_gradient_clipping
D2GO_OPTIM_MAPPER_REGISTRY = Registry("D2GO_OPTIM_MAPPER")
def get_default_optimizer_params(
    model: torch.nn.Module,
    base_lr,
    weight_decay,
    weight_decay_norm,
    bias_lr_factor=1.0,
    weight_decay_bias=None,
    overrides: Optional[Dict[str, Dict[str, float]]] = None,
    lr_multipliers_overwrite: Optional[Dict[str, float]] = None,
):
    """
    Get default param list for optimizer
    Args:
        overrides (dict: str -> (dict: str -> float)):
            if not `None`, provides values for optimizer hyperparameters
            (LR, weight decay) for module parameters with a given name; e.g.
            {"embedding": {"lr": 0.01, "weight_decay": 0.1}} will set the LR and
            weight decay values for all module parameters named `embedding` (default: None)
        lr_multipliers_overwrite (dict: str-> float):
            Applying different lr multiplier to a set of parameters whose names
            containing certain keys. For example, if lr_multipliers_overwrite={'backbone': 0.1},
            the LR for the parameters whose names containing 'backbone' will be scaled to 0.1x.
            Set lr_multipliers_overwrite={} if no multipliers required.
    """
    if weight_decay_bias is None:
        weight_decay_bias = weight_decay
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        # NaiveSyncBatchNorm inherits from BatchNorm2d
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )
    params: List[Dict[str, Any]] = []
    seen: Set[torch.nn.parameter.Parameter] = set()
    for module in model.modules():
        is_norm_module = isinstance(module, norm_module_types)
        for param_name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue
            # Avoid duplicating parameters shared between modules.
            if param in seen:
                continue
            seen.add(param)

            hyperparams = {"lr": base_lr, "weight_decay": weight_decay}
            if is_norm_module:
                hyperparams["weight_decay"] = weight_decay_norm
            elif param_name == "bias":
                # NOTE: unlike Detectron v1, we now default BIAS_LR_FACTOR to 1.0
                # and WEIGHT_DECAY_BIAS to WEIGHT_DECAY so that bias optimizer
                # hyperparameters are by default exactly the same as for regular
                # weights.
                hyperparams["lr"] = base_lr * bias_lr_factor
                hyperparams["weight_decay"] = weight_decay_bias
            if overrides is not None and param_name in overrides:
                hyperparams.update(overrides[param_name])
            if lr_multipliers_overwrite is not None:
                for kname, mult in lr_multipliers_overwrite.items():
                    if kname in param_name:
                        # apply multiplier for the params containing kname, e.g. backbone
                        hyperparams["lr"] = hyperparams["lr"] * mult
            params.append(
                {
                    "params": [param],
                    "lr": hyperparams["lr"],
                    "weight_decay": hyperparams["weight_decay"],
                }
            )
    return params
def maybe_add_gradient_clipping(cfg, optim):  # optim: the optimizer class
    """
    Return an optimizer class with full-model gradient clipping applied in
    `step()` when configured; otherwise defer to detectron2's helper.
    """
    # detectron2 doesn't have full model gradient clipping now
    clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
    use_full_model_clipping = (
        cfg.SOLVER.CLIP_GRADIENTS.ENABLED
        and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
        and clip_norm_val > 0.0
    )

    if not use_full_model_clipping:
        return d2_maybe_add_gradient_clipping(cfg, optim)

    class FullModelGradientClippingOptimizer(optim):
        def step(self, closure=None):
            # Clip the joint norm over all parameter groups before stepping.
            all_params = itertools.chain.from_iterable(
                group["params"] for group in self.param_groups
            )
            torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
            super().step(closure=closure)

    return FullModelGradientClippingOptimizer
def _merge_dict(in_dict):
ret_dict = {}
assert all(isinstance(x, dict) for x in in_dict)
for dic in in_dict:
ret_dict.update(dic)
return ret_dict
@D2GO_OPTIM_MAPPER_REGISTRY.register()
def sgd(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build an SGD optimizer (with optional gradient clipping) from config.
    """
    solver = cfg.SOLVER
    params = get_default_optimizer_params(
        model,
        base_lr=solver.BASE_LR,
        weight_decay=solver.WEIGHT_DECAY,
        weight_decay_norm=solver.WEIGHT_DECAY_NORM,
        bias_lr_factor=solver.BIAS_LR_FACTOR,
        weight_decay_bias=solver.WEIGHT_DECAY_BIAS,
        lr_multipliers_overwrite=_merge_dict(solver.LR_MULTIPLIER_OVERWRITE),
    )
    optimizer_cls = maybe_add_gradient_clipping(cfg, torch.optim.SGD)
    return optimizer_cls(
        params,
        solver.BASE_LR,
        momentum=solver.MOMENTUM,
        nesterov=solver.NESTEROV,
    )
@D2GO_OPTIM_MAPPER_REGISTRY.register()
def adamw(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build an AdamW optimizer (with optional gradient clipping) from config.
    """
    solver = cfg.SOLVER
    params = get_default_optimizer_params(
        model,
        base_lr=solver.BASE_LR,
        weight_decay=solver.WEIGHT_DECAY,
        weight_decay_norm=solver.WEIGHT_DECAY_NORM,
        bias_lr_factor=solver.BIAS_LR_FACTOR,
        weight_decay_bias=solver.WEIGHT_DECAY_BIAS,
        lr_multipliers_overwrite=_merge_dict(solver.LR_MULTIPLIER_OVERWRITE),
    )
    optimizer_cls = maybe_add_gradient_clipping(cfg, torch.optim.AdamW)
    return optimizer_cls(params, solver.BASE_LR)
def build_optimizer_mapper(cfg, model):
    """Look up the optimizer builder named by cfg.SOLVER.OPTIMIZER and invoke it."""
    optimizer_name = cfg.SOLVER.OPTIMIZER.lower()
    builder = D2GO_OPTIM_MAPPER_REGISTRY.get(optimizer_name)
    return builder(cfg, model)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import importlib
from typing import Type
from .default_runner import BaseRunner, Detectron2GoRunner, GeneralizedRCNNRunner
def get_class(class_full_name: str) -> Type:
    """Imports and returns the task class."""
    # Split "pkg.module.ClassName" into module path and attribute name.
    module_name, class_name = class_full_name.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
def create_runner(class_full_name: str, *args, **kwargs) -> BaseRunner:
    """Constructs a runner instance of the given class."""
    runner_cls = get_class(class_full_name)
    return runner_cls(*args, **kwargs)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import random
import torch
import torch.nn as nn
from d2go.modeling.quantization import QATCheckpointer
from d2go.runner.default_runner import (
BaseRunner,
add_tensorboard_default_configs,
)
from fvcore.common.file_io import PathManager
class DebugRunner(BaseRunner):
    """Minimal runner producing fake checkpoints/metrics, for workflow testing."""

    def get_default_cfg(self):
        _C = super().get_default_cfg()
        # _C.TENSORBOARD...
        add_tensorboard_default_configs(_C)
        # target metric
        _C.TEST.TARGET_METRIC = "dataset0:dummy0:metric1"
        return _C

    def build_model(self, cfg, eval_only=False):
        # An empty module is enough for exercising checkpoint/eval plumbing.
        return nn.Sequential()

    def do_test(self, cfg, model, train_iter=None):
        # Random values mimic the nested dataset -> task -> metric structure.
        return {
            "dataset0": {
                "dummy0": {"metric0": random.random(), "metric1": random.random()}
            }
        }

    def do_train(self, cfg, model, resume):
        # Save dummy checkpoint files. The original repeated the identical
        # save block three times; collapsed into a loop over checkpoint names.
        for checkpoint_name in ("model_123", "model_12345", "model_final"):
            save_file = os.path.join(cfg.OUTPUT_DIR, "{}.pth".format(checkpoint_name))
            with PathManager.open(save_file, "wb") as f:
                torch.save({"model": model.state_dict()}, f)

    def build_checkpointer(self, cfg, model, save_dir, **kwargs):
        checkpointer = QATCheckpointer(model, save_dir=save_dir, **kwargs)
        return checkpointer

    @staticmethod
    def final_model_name():
        return "model_final"
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import contextlib
import logging
import math
import os
from collections import OrderedDict
from functools import lru_cache, partial
from typing import Type
import d2go.utils.abnormal_checker as abnormal_checker
import detectron2.utils.comm as comm
import mock
import torch
from d2go.config import CfgNode as CN, CONFIG_SCALING_METHOD_REGISTRY, temp_defrost, get_cfg_diff_table
from d2go.data.build import (
build_weighted_detection_train_loader,
)
from d2go.data.dataset_mappers import build_dataset_mapper
from d2go.data.datasets import inject_coco_datasets, register_dynamic_datasets
from d2go.data.transforms.build import build_transform_gen
from d2go.data.utils import (
maybe_subsample_n_images,
update_cfg_if_using_adhoc_dataset,
)
from d2go.export.caffe2_model_helper import update_cfg_from_pb_model
from d2go.export.d2_meta_arch import patch_d2_meta_arch
from d2go.modeling import kmeans_anchors, model_ema
from d2go.modeling.model_freezing_utils import (
set_requires_grad,
)
from d2go.modeling.quantization import (
QATCheckpointer,
setup_qat_model,
silicon_qat_build_model_context,
)
from d2go.utils.flop_calculator import add_print_flops_callback
from d2go.utils.misc import get_tensorboard_log_dir
from d2go.utils.visualization import DataLoaderVisWrapper, VisualizationEvaluator
from d2go.utils.get_default_cfg import get_default_cfg
from d2go.optimizer import build_optimizer_mapper
from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer
from detectron2.data import (
build_detection_test_loader as d2_build_detection_test_loader,
build_detection_train_loader as d2_build_detection_train_loader,
MetadataCatalog,
)
from detectron2.engine import SimpleTrainer, AMPTrainer, hooks
from detectron2.evaluation import (
COCOEvaluator,
RotatedCOCOEvaluator,
DatasetEvaluators,
inference_on_dataset,
print_csv_format,
verify_results,
)
from detectron2.export.caffe2_inference import ProtobufDetectionModel
from detectron2.export.caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP
from d2go.utils.helper import TensorboardXWriter, D2Trainer
from detectron2.modeling import GeneralizedRCNNWithTTA, build_model
from detectron2.solver import (
build_lr_scheduler as d2_build_lr_scheduler,
build_optimizer as d2_build_optimizer,
)
from detectron2.utils.events import CommonMetricPrinter, JSONWriter
from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat
logger = logging.getLogger(__name__)
@contextlib.contextmanager
def _mock_func(module, src_func, target_func):
with mock.patch(
"{}.{}".format(module.__name__, src_func.__name__), side_effect=target_func
) as mocked_func:
yield
if not mocked_func.call_count >= 1:
logger.warning("Didn't patch the {} in module {}".format(src_func, module))
ALL_TB_WRITERS = []
@lru_cache()
def _get_tbx_writer(log_dir):
ret = TensorboardXWriter(log_dir)
ALL_TB_WRITERS.append(ret)
return ret
def _close_all_tbx_writers():
for x in ALL_TB_WRITERS:
x.close()
ALL_TB_WRITERS.clear()
@CONFIG_SCALING_METHOD_REGISTRY.register()
def default_scale_d2_configs(cfg, new_world_size):
    """Rescale core D2 solver/test configs for a new world size (GPU count)."""
    gpu_scale = new_world_size / cfg.SOLVER.REFERENCE_WORLD_SIZE

    # LR scaling depends on the optimizer: sgd follows the linear scaling rule,
    # adamw keeps the LR unchanged; unknown optimizers default to linear scaling.
    lr_scales = {
        "sgd": gpu_scale,
        "adamw": 1,
    }
    lr_scale = lr_scales.get(cfg.SOLVER.OPTIMIZER.lower(), gpu_scale)

    # default configs in D2
    cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * lr_scale
    cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / gpu_scale))
    cfg.SOLVER.STEPS = tuple(int(round(s / gpu_scale)) for s in cfg.SOLVER.STEPS)
    cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / gpu_scale))
    cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * gpu_scale))
    cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / gpu_scale))
@CONFIG_SCALING_METHOD_REGISTRY.register()
def default_scale_quantization_configs(cfg, new_world_size):
    """Rescale QAT iteration milestones for a new world size (GPU count)."""
    gpu_scale = new_world_size / cfg.SOLVER.REFERENCE_WORLD_SIZE

    # Scale QUANTIZATION related configs
    qat = cfg.QUANTIZATION.QAT
    for key in (
        "START_ITER",
        "ENABLE_OBSERVER_ITER",
        "DISABLE_OBSERVER_ITER",
        "FREEZE_BN_ITER",
    ):
        setattr(qat, key, int(round(getattr(qat, key) / gpu_scale)))
class BaseRunner(object):
    """Base class for D2Go runners: owns config, model building and train/test."""

    def _initialize(self, cfg):
        """ Runner should be initialized in the sub-process in ddp setting """
        if getattr(self, "_has_initialized", False):
            logger.warning("Runner has already been initialized, skip initialization.")
            return
        self._has_initialized = True
        self.register(cfg)

    def register(self, cfg):
        """
        Override `register` in order to run customized code before other things like:
            - registering datasets.
            - registering model using Registry.
        """
        pass

    @staticmethod
    def get_default_cfg():
        """
        Override `get_default_cfg` for adding non common config.
        """
        from detectron2.config import get_cfg as get_d2_cfg

        # upgrade from D2's CfgNode to D2Go's CfgNode
        cfg = CN(get_d2_cfg())
        cfg.SOLVER.AUTO_SCALING_METHODS = ["default_scale_d2_configs"]
        return cfg

    def build_model(self, cfg, eval_only=False):
        # cfg may need to be reused to build trace model again, thus clone
        model = build_model(cfg.clone())
        if eval_only:
            checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
            checkpointer.load(cfg.MODEL.WEIGHTS)
            model.eval()
        return model

    def build_traceable_model(self, cfg, built_model=None):
        """
        Return a traceable model. The returned model has to be a
        `Caffe2MetaArch` which provides the following two member methods:
            - get_caffe2_inputs: it'll be called when exporting the model
                to convert D2's batched_input to a list of Tensors.
            - encode_additional_info: this allow editing exported predict_net/init_net.
        """
        return built_model

    def build_caffe2_model(self, predict_net, init_net):
        """
        Return a nn.Module which should behave the same as a normal D2 model.
        """
        raise NotImplementedError()

    def do_test(self, *args, **kwargs):
        raise NotImplementedError()

    def do_train(self, *args, **kwargs):
        raise NotImplementedError()

    @classmethod
    def build_detection_test_loader(cls, *args, **kwargs):
        return d2_build_detection_test_loader(*args, **kwargs)

    @classmethod
    def build_detection_train_loader(cls, *args, **kwargs):
        return d2_build_detection_train_loader(*args, **kwargs)
class Detectron2GoRunner(BaseRunner):
    """Default D2Go runner.

    Builds models / optimizers / data loaders from cfg and implements the
    standard train and test loops, with optional EMA weights and
    quantization-aware training (QAT) support.
    """
    def register(self, cfg):
        """Register datasets and patch meta-archs before anything else runs."""
        super().register(cfg)
        # Keep an untouched clone so training-time cfg mutations can be diffed
        # against it when reporting the final config (see do_train).
        self.original_cfg = cfg.clone()
        inject_coco_datasets(cfg)
        register_dynamic_datasets(cfg)
        update_cfg_if_using_adhoc_dataset(cfg)
        patch_d2_meta_arch()
    @staticmethod
    def get_default_cfg():
        """Extend the BaseRunner defaults with D2Go-specific config keys."""
        _C = super(Detectron2GoRunner, Detectron2GoRunner).get_default_cfg()
        return get_default_cfg(_C)
    def build_model(self, cfg, eval_only=False):
        """Build the model, optionally with EMA / frozen layers / QAT wrapping.

        Args:
            cfg: config; cloned internally so the caller's copy is not mutated.
            eval_only: when True, load cfg.MODEL.WEIGHTS, switch to eval mode,
                and (if configured) apply EMA weights.
        """
        # build_model might modify the cfg, thus clone
        cfg = cfg.clone()
        # silicon_qat_build_model_context is deprecated
        with silicon_qat_build_model_context(cfg):
            model = build_model(cfg)
        model_ema.may_build_model_ema(cfg, model)
        if cfg.MODEL.FROZEN_LAYER_REG_EXP:
            # Freeze parameters whose names match any of the given regexes.
            set_requires_grad(model, cfg.MODEL.FROZEN_LAYER_REG_EXP, False)
        if cfg.QUANTIZATION.QAT.ENABLED:
            # Disable fake_quant and observer so that the model will be trained normally
            # before QAT being turned on (controlled by QUANTIZATION.QAT.START_ITER).
            model = setup_qat_model(
                cfg, model, enable_fake_quant=eval_only, enable_observer=False
            )
        if eval_only:
            checkpointer = self.build_checkpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
            checkpointer.load(cfg.MODEL.WEIGHTS)
            model.eval()
            if cfg.MODEL_EMA.ENABLED and cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
                model_ema.apply_model_ema(model)
        return model
    def build_checkpointer(self, cfg, model, save_dir, **kwargs):
        """Create a QAT-aware checkpointer (with EMA state when enabled)."""
        kwargs.update(model_ema.may_get_ema_checkpointer(cfg, model))
        checkpointer = QATCheckpointer(model, save_dir=save_dir, **kwargs)
        return checkpointer
    def build_optimizer(self, cfg, model):
        """Build the optimizer via the D2Go optimizer mapper."""
        return build_optimizer_mapper(cfg, model)
    def build_lr_scheduler(self, cfg, optimizer):
        """Build the LR scheduler using detectron2's default builder."""
        return d2_build_lr_scheduler(cfg, optimizer)
    def _do_test(self, cfg, model, train_iter=None, model_tag="default"):
        """train_iter: Current iteration of the model, None means final iteration"""
        assert len(cfg.DATASETS.TEST)
        assert cfg.OUTPUT_DIR
        is_final = (train_iter is None) or (train_iter == cfg.SOLVER.MAX_ITER - 1)
        logger.info(
            f"Running evaluation for model tag {model_tag} at iter {train_iter}..."
        )
        def _get_inference_dir_name(base_dir, inference_type, dataset_name):
            # Layout: <base>/<type>/<tag>/<iter or "final">/<dataset>
            return os.path.join(
                base_dir,
                inference_type,
                model_tag,
                str(train_iter) if train_iter is not None else "final",
                dataset_name,
            )
        add_print_flops_callback(cfg, model, disable_after_callback=True)
        # results is keyed by model tag, then by dataset name.
        results = OrderedDict()
        results[model_tag] = OrderedDict()
        for dataset_name in cfg.DATASETS.TEST:
            # Evaluator will create output folder, no need to create here
            output_folder = _get_inference_dir_name(
                cfg.OUTPUT_DIR, "inference", dataset_name
            )
            # NOTE: creating evaluator after dataset is loaded as there might be dependency. # noqa
            data_loader = self.build_detection_test_loader(cfg, dataset_name)
            evaluator = self.get_evaluator(
                cfg, dataset_name, output_folder=output_folder
            )
            if not isinstance(evaluator, DatasetEvaluators):
                evaluator = DatasetEvaluators([evaluator])
            if comm.is_main_process():
                # Only the main process writes visualizations to tensorboard.
                tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
                logger.info("Adding visualization evaluator ...")
                mapper = self.get_mapper(cfg, is_train=False)
                evaluator._evaluators.append(
                    self.get_visualization_evaluator()(
                        cfg,
                        tbx_writer,
                        mapper,
                        dataset_name,
                        train_iter=train_iter,
                        tag_postfix=model_tag,
                    )
                )
            results_per_dataset = inference_on_dataset(model, data_loader, evaluator)
            if comm.is_main_process():
                results[model_tag][dataset_name] = results_per_dataset
                if is_final:
                    print_csv_format(results_per_dataset)
            if is_final and cfg.TEST.AUG.ENABLED:
                # In the end of training, run an evaluation with TTA
                # Only support some R-CNN models.
                output_folder = _get_inference_dir_name(
                    cfg.OUTPUT_DIR, "inference_TTA", dataset_name
                )
                logger.info("Running inference with test-time augmentation ...")
                data_loader = self.build_detection_test_loader(
                    cfg, dataset_name, mapper=lambda x: x
                )
                evaluator = self.get_evaluator(
                    cfg, dataset_name, output_folder=output_folder
                )
                inference_on_dataset(
                    GeneralizedRCNNWithTTA(cfg, model), data_loader, evaluator
                )
        if is_final and cfg.TEST.EXPECTED_RESULTS and comm.is_main_process():
            # NOTE(review): this asserts on the number of model tags, while the
            # message refers to datasets — likely intended to be
            # len(cfg.DATASETS.TEST) == 1; confirm before changing.
            assert len(results) == 1, "Results verification only supports one dataset!"
            verify_results(cfg, results[model_tag][cfg.DATASETS.TEST[0]])
        # write results to tensorboard
        if comm.is_main_process() and results:
            from detectron2.evaluation.testing import flatten_results_dict
            flattened_results = flatten_results_dict(results)
            for k, v in flattened_results.items():
                tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
                tbx_writer._writer.add_scalar("eval_{}".format(k), v, train_iter)
        if comm.is_main_process():
            tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
            tbx_writer._writer.flush()
        return results
    def do_test(self, cfg, model, train_iter=None):
        """Evaluate the model (and its EMA copy if enabled) on all test sets.

        Returns a dict keyed by model tag ("default", optionally "ema"),
        each mapping dataset name -> evaluation results.
        """
        results = OrderedDict()
        with maybe_subsample_n_images(cfg) as new_cfg:
            # default model
            cur_results = self._do_test(
                new_cfg, model, train_iter=train_iter, model_tag="default"
            )
            results.update(cur_results)
            # model with ema weights
            if cfg.MODEL_EMA.ENABLED:
                logger.info("Run evaluation with EMA.")
                with model_ema.apply_model_ema_and_restore(model):
                    cur_results = self._do_test(
                        new_cfg, model, train_iter=train_iter, model_tag="ema"
                    )
                    results.update(cur_results)
        return results
    def do_train(self, cfg, model, resume):
        """Run the full training loop.

        Args:
            cfg: frozen training config.
            model: the model to train (may be wrapped for abnormal-loss checks).
            resume: when True and a checkpoint exists, continue from it.

        Returns:
            {"model_final": trained_cfg} where trained_cfg points MODEL.WEIGHTS
            at the last checkpoint written.
        """
        add_print_flops_callback(cfg, model, disable_after_callback=True)
        optimizer = self.build_optimizer(cfg, model)
        scheduler = self.build_lr_scheduler(cfg, optimizer)
        checkpointer = self.build_checkpointer(
            cfg,
            model,
            save_dir=cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=scheduler,
        )
        checkpoint = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume)
        start_iter = (
            checkpoint.get("iteration", -1)
            if resume and checkpointer.has_checkpoint()
            else -1
        )
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration (or iter zero if there's no checkpoint).
        start_iter += 1
        max_iter = cfg.SOLVER.MAX_ITER
        periodic_checkpointer = PeriodicCheckpointer(
            checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
        )
        data_loader = self.build_detection_train_loader(cfg)
        def _get_model_with_abnormal_checker(model):
            # Optionally wrap the model so abnormal losses are detected and
            # dumped (see abnormal_checker module).
            if not cfg.ABNORMAL_CHECKER.ENABLED:
                return model
            tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
            writers = abnormal_checker.get_writers(cfg, tbx_writer)
            checker = abnormal_checker.AbnormalLossChecker(start_iter, writers)
            ret = abnormal_checker.AbnormalLossCheckerWrapper(model, checker)
            return ret
        trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
            _get_model_with_abnormal_checker(model), data_loader, optimizer
        )
        # NOTE(review): disabled features contribute None entries; presumably
        # register_hooks tolerates/filters None (D2 behavior) — confirm.
        trainer_hooks = [
            hooks.IterationTimer(),
            model_ema.EMAHook(cfg, model) if cfg.MODEL_EMA.ENABLED else None,
            self._create_after_step_hook(
                cfg, model, optimizer, scheduler, periodic_checkpointer
            ),
            hooks.EvalHook(
                cfg.TEST.EVAL_PERIOD,
                lambda: self.do_test(cfg, model, train_iter=trainer.iter),
            ),
            kmeans_anchors.compute_kmeans_anchors_hook(self, cfg),
            self._create_qat_hook(cfg) if cfg.QUANTIZATION.QAT.ENABLED else None,
        ]
        if comm.is_main_process():
            tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
            writers = [
                CommonMetricPrinter(max_iter),
                JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
                tbx_writer,
            ]
            trainer_hooks.append(hooks.PeriodicWriter(writers))
        trainer.register_hooks(trainer_hooks)
        trainer.train(start_iter, max_iter)
        # Report the pristine config captured in register(); training-time
        # changes to cfg are intentionally not propagated to the result.
        if hasattr(self, 'original_cfg'):
            table = get_cfg_diff_table(cfg, self.original_cfg)
            logger.info("GeneralizeRCNN Runner ignoring training config change: \n" + table)
            trained_cfg = self.original_cfg.clone()
        else:
            trained_cfg = cfg.clone()
        with temp_defrost(trained_cfg):
            trained_cfg.MODEL.WEIGHTS = checkpointer.get_checkpoint_file()
        return {"model_final": trained_cfg}
    @classmethod
    def build_detection_test_loader(cls, cfg, dataset_name, mapper=None):
        """Build a test loader for `dataset_name`, defaulting to our mapper."""
        logger.info(
            "Building detection test loader for dataset: {} ...".format(dataset_name)
        )
        mapper = mapper or cls.get_mapper(cfg, is_train=False)
        logger.info("Using dataset mapper:\n{}".format(mapper))
        return d2_build_detection_test_loader(cfg, dataset_name, mapper=mapper)
    @classmethod
    def build_detection_train_loader(cls, cfg, *args, mapper=None, **kwargs):
        """Build the train loader; wraps it for visualization on main process."""
        logger.info("Building detection train loader ...")
        mapper = mapper or cls.get_mapper(cfg, is_train=True)
        logger.info("Using dataset mapper:\n{}".format(mapper))
        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
        if sampler_name == "WeightedTrainingSampler":
            data_loader = build_weighted_detection_train_loader(cfg, mapper=mapper)
        else:
            data_loader = d2_build_detection_train_loader(
                cfg, *args, mapper=mapper, **kwargs
            )
        if comm.is_main_process():
            tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
            data_loader = cls.get_data_loader_vis_wrapper()(
                cfg, tbx_writer, data_loader
            )
        return data_loader
    @staticmethod
    def get_data_loader_vis_wrapper() -> Type[DataLoaderVisWrapper]:
        """Hook point for subclasses to swap the train-loader visualizer."""
        return DataLoaderVisWrapper
    @staticmethod
    def get_evaluator(cfg, dataset_name, output_folder):
        """Pick an evaluator based on the dataset's registered evaluator_type."""
        evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
        if evaluator_type in ["coco", "coco_panoptic_seg"]:
            # D2 is in the process of reducing the use of cfg.
            dataset_evaluators = COCOEvaluator(
                dataset_name,
                output_dir=output_folder,
                kpt_oks_sigmas=cfg.TEST.KEYPOINT_OKS_SIGMAS,
            )
        elif evaluator_type in ["rotated_coco"]:
            dataset_evaluators = DatasetEvaluators(
                [RotatedCOCOEvaluator(dataset_name, cfg, True, output_folder)]
            )
        else:
            # Fall back to detectron2's default evaluator selection.
            dataset_evaluators = D2Trainer.build_evaluator(
                cfg, dataset_name, output_folder
            )
        if not isinstance(dataset_evaluators, DatasetEvaluators):
            dataset_evaluators = DatasetEvaluators([dataset_evaluators])
        return dataset_evaluators
    @staticmethod
    def get_mapper(cfg, is_train):
        """Build the dataset mapper (with transform gens) for train or test."""
        tfm_gens = build_transform_gen(cfg, is_train)
        mapper = build_dataset_mapper(cfg, is_train, tfm_gens=tfm_gens)
        return mapper
    @staticmethod
    def get_visualization_evaluator() -> Type[VisualizationEvaluator]:
        """Hook point for subclasses to swap the visualization evaluator."""
        return VisualizationEvaluator
    @staticmethod
    def final_model_name():
        # Key under which do_train returns the trained config.
        return "model_final"
    def _create_after_step_hook(
        self, cfg, model, optimizer, scheduler, periodic_checkpointer
    ):
        """
        Create a hook that performs some pre-defined tasks used in this script
        (evaluation, LR scheduling, checkpointing).
        """
        def after_step_callback(trainer):
            trainer.storage.put_scalar(
                "lr", optimizer.param_groups[0]["lr"], smoothing_hint=False
            )
            scheduler.step()
            # Note: when precise BN is enabled, some checkpoints will have more precise
            # statistics than others, if they are saved immediately after eval.
            if comm.is_main_process():
                periodic_checkpointer.step(trainer.iter)
        return hooks.CallbackHook(after_step=after_step_callback)
    def _create_qat_hook(self, cfg):
        """
        Create a hook to start QAT (during training) and/or change the phase of QAT.
        """
        # Tracks which one-shot QAT phase transitions have already fired.
        applied = {
            "enable_fake_quant": False,
            "enable_observer": False,
            "disable_observer": False,
            "freeze_bn_stats": False,
        }
        assert (
            cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER
            <= cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
        ), "Can't diable observer before enabling it"
        def qat_before_step_callback(trainer):
            if (
                not applied["enable_fake_quant"]
                and trainer.iter >= cfg.QUANTIZATION.QAT.START_ITER
            ):
                logger.info(
                    "[QAT] enable fake quant to start QAT, iter = {}".format(
                        trainer.iter
                    )
                )
                trainer.model.apply(torch.quantization.enable_fake_quant)
                applied["enable_fake_quant"] = True
                # Optionally shrink/grow the per-GPU batch size once QAT starts.
                if cfg.QUANTIZATION.QAT.BATCH_SIZE_FACTOR != 1.0:
                    loader_cfg = cfg.clone()
                    loader_cfg.defrost()
                    num_gpus = comm.get_world_size()
                    old_bs = cfg.SOLVER.IMS_PER_BATCH // num_gpus
                    new_bs = math.ceil(old_bs * cfg.QUANTIZATION.QAT.BATCH_SIZE_FACTOR)
                    loader_cfg.SOLVER.IMS_PER_BATCH = new_bs * num_gpus
                    loader_cfg.freeze()
                    logger.info(
                        "[QAT] Rebuild data loader with batch size per GPU: {} -> {}".format(
                            old_bs, new_bs
                        )
                    )
                    # This method assumes the data loader can be replaced from trainer
                    assert trainer.__class__ == SimpleTrainer
                    del trainer._data_loader_iter
                    del trainer.data_loader
                    data_loader = self.build_detection_train_loader(loader_cfg)
                    trainer.data_loader = data_loader
                    trainer._data_loader_iter = iter(data_loader)
            if (
                not applied["enable_observer"]
                and trainer.iter >= cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER
                and trainer.iter < cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
            ):
                logger.info("[QAT] enable observer, iter = {}".format(trainer.iter))
                trainer.model.apply(torch.quantization.enable_observer)
                applied["enable_observer"] = True
            if (
                not applied["disable_observer"]
                and trainer.iter >= cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
            ):
                logger.info(
                    "[QAT] disabling observer for sub seq iters, iter = {}".format(
                        trainer.iter
                    )
                )
                trainer.model.apply(torch.quantization.disable_observer)
                applied["disable_observer"] = True
            if (
                not applied["freeze_bn_stats"]
                and trainer.iter >= cfg.QUANTIZATION.QAT.FREEZE_BN_ITER
            ):
                logger.info(
                    "[QAT] freezing BN for subseq iters, iter = {}".format(trainer.iter)
                )
                trainer.model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
                applied["freeze_bn_stats"] = True
            if (
                applied["enable_fake_quant"]
                and cfg.QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIODICALLY
                and trainer.iter % cfg.QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIOD
                == 0
            ):
                logger.info(f"[QAT] updating observers, iter = {trainer.iter}")
                trainer.model.apply(observer_update_stat)
        return hooks.CallbackHook(before_step=qat_before_step_callback)
class GeneralizedRCNNRunner(Detectron2GoRunner):
    """Runner specialization for GeneralizedRCNN models with caffe2 export."""

    @staticmethod
    def get_default_cfg():
        """Extend the Detectron2Go defaults with caffe2-export options."""
        cfg = super(GeneralizedRCNNRunner, GeneralizedRCNNRunner).get_default_cfg()
        cfg.EXPORT_CAFFE2 = CN()
        cfg.EXPORT_CAFFE2.USE_HEATMAP_MAX_KEYPOINT = False
        return cfg

    def build_traceable_model(self, cfg, built_model=None):
        """Wrap (or first build) the model into its Caffe2 export meta-arch."""
        if built_model is None:
            built_model = self.build_model(cfg, eval_only=True)
        else:
            logger.warning("The given built_model will be modified")
        logger.info("Model:\n{}".format(built_model))
        Caffe2ModelType = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[cfg.MODEL.META_ARCHITECTURE]
        return Caffe2ModelType(cfg, torch_model=built_model)

    def build_caffe2_model(self, predict_net, init_net):
        """Wrap exported protobufs into a module that mimics a D2 model."""
        pb_model = ProtobufDetectionModel(predict_net, init_net)
        pb_model.validate_cfg = partial(update_cfg_from_pb_model, model=pb_model)
        return pb_model
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
import pytorch_lightning as pl
import torch
from d2go.config import CfgNode
from d2go.runner.default_runner import (
Detectron2GoRunner,
GeneralizedRCNNRunner,
)
from d2go.setup import setup_after_launch
from d2go.utils.ema_state import EMAState
from detectron2.modeling import build_model
from detectron2.solver import (
build_lr_scheduler as d2_build_lr_scheduler,
build_optimizer as d2_build_optimizer,
)
from pytorch_lightning.utilities import rank_zero_info
# Top-level key holding model weights in a Lightning checkpoint.
_STATE_DICT_KEY = "state_dict"
# Top-level key holding model weights in a legacy D2Go/Detectron2 checkpoint.
_OLD_STATE_DICT_KEY = "model"
def _is_lightning_checkpoint(checkpoint: Dict[str, Any]) -> bool:
    """Return True when *checkpoint* looks like a Lightning checkpoint."""
    # Lightning checkpoints store model weights under the "state_dict" key.
    return _STATE_DICT_KEY in checkpoint
def _is_d2go_checkpoint(checkpoint: Dict[str, Any]) -> bool:
    """Return True when *checkpoint* looks like a legacy D2Go checkpoint."""
    # A D2Go checkpoint carries all four of these top-level keys.
    required_keys = (_OLD_STATE_DICT_KEY, "optimizer", "scheduler", "iteration")
    return all(key in checkpoint for key in required_keys)
def _convert_to_lightning(d2_checkpoint: Dict[str, Any]) -> None:
    """Convert a D2Go checkpoint to Lightning format in-place by renaming keys."""
    prefix = "model"  # based on DefaultTask.model.
    # Prefix every weight name so it matches the LightningModule attribute path.
    state = d2_checkpoint[_OLD_STATE_DICT_KEY]
    for key in list(state):
        state[f"{prefix}.{key}"] = state.pop(key)
    # Rename the top-level bookkeeping keys Lightning expects.
    d2_checkpoint[_STATE_DICT_KEY] = d2_checkpoint.pop(_OLD_STATE_DICT_KEY)
    d2_checkpoint["global_step"] = d2_checkpoint.pop("iteration")
    # Lightning stores optimizer/scheduler states as lists.
    d2_checkpoint["optimizer_states"] = [d2_checkpoint.pop("optimizer")]
    d2_checkpoint["lr_schedulers"] = [d2_checkpoint.pop("scheduler")]
    d2_checkpoint["epoch"] = 0
class ModelTag(str, Enum):
    """Identifies which copy of the model produced a result."""
    # Plain (non-averaged) model weights.
    DEFAULT = "default"
    # Exponential-moving-average weights (enabled via cfg.MODEL_EMA).
    EMA = "ema"
class DefaultTask(pl.LightningModule):
    """Lightning wrapper around a D2Go/Detectron2 model.

    Mirrors Detectron2GoRunner's behavior (EMA, per-dataset evaluators,
    D2-style optimizers/schedulers) inside the LightningModule protocol.
    """
    def __init__(self, cfg: CfgNode):
        super().__init__()
        self.cfg = cfg
        self.model = build_model(cfg)
        # EventStorage; assigned externally before training steps run —
        # training_step calls self.storage.step().
        self.storage = None
        # evaluators for validation datasets, split by model tag(default, ema),
        # in the order of DATASETS.TEST
        self.dataset_evaluators = {ModelTag.DEFAULT: []}
        self.save_hyperparameters()
        # Nested {tag: {dataset: results}} filled by _log_dataset_evaluation_results.
        self.eval_res = None
        self.ema_state: Optional[EMAState] = None
        if cfg.MODEL_EMA.ENABLED:
            self.ema_state = EMAState(
                decay=cfg.MODEL_EMA.DECAY,
                device=cfg.MODEL_EMA.DEVICE or cfg.MODEL.DEVICE,
            )
            # Separate module instance that EMA weights get applied to for eval.
            self.model_ema = deepcopy(self.model)
            self.dataset_evaluators[ModelTag.EMA] = []
    def setup(self, stage: str):
        """Lightning hook: run D2Go post-launch setup (dirs, loggers, cfg)."""
        setup_after_launch(self.cfg, self.cfg.OUTPUT_DIR, runner=None)
    @classmethod
    def get_default_cfg(cls):
        """Default config matches the Detectron2GoRunner defaults."""
        return Detectron2GoRunner.get_default_cfg()
    def training_step(self, batch, batch_idx):
        """Sum the model's loss dict into a single scalar for optimization."""
        loss_dict = self.forward(batch)
        losses = sum(loss_dict.values())
        self.storage.step()
        self.log_dict(loss_dict, prog_bar=True)
        return losses
    def test_step(self, batch, batch_idx: int, dataloader_idx: int = 0) -> None:
        self._evaluation_step(batch, batch_idx, dataloader_idx)
    def validation_step(self, batch, batch_idx: int, dataloader_idx: int = 0) -> None:
        self._evaluation_step(batch, batch_idx, dataloader_idx)
    def _evaluation_step(self, batch, batch_idx: int, dataloader_idx: int) -> None:
        """Feed one eval batch to the per-dataset evaluator(s), incl. EMA copy."""
        if not isinstance(batch, List):
            batch = [batch]
        outputs = self.forward(batch)
        self.dataset_evaluators[ModelTag.DEFAULT][dataloader_idx].process(
            batch, outputs
        )
        if self.ema_state:
            ema_outputs = self.model_ema(batch)
            self.dataset_evaluators[ModelTag.EMA][dataloader_idx].process(
                batch, ema_outputs
            )
    def _log_dataset_evaluation_results(self) -> None:
        """Collect evaluator results per tag/dataset, store and log them."""
        nested_res = {}
        for tag, evaluators in self.dataset_evaluators.items():
            res = {}
            for idx, evaluator in enumerate(evaluators):
                # Evaluators were appended in DATASETS.TEST order.
                dataset_name = self.cfg.DATASETS.TEST[idx]
                res[dataset_name] = evaluator.evaluate()
            nested_res[tag.value] = res
        self.eval_res = nested_res
        flattened = pl.loggers.LightningLoggerBase._flatten_dict(nested_res)
        self.log_dict(flattened)
    def test_epoch_end(self, _outputs) -> None:
        self._evaluation_epoch_end()
    def validation_epoch_end(self, _outputs) -> None:
        self._evaluation_epoch_end()
    def _evaluation_epoch_end(self) -> None:
        """Log accumulated results, then rebuild evaluators for the next round."""
        self._log_dataset_evaluation_results()
        self._reset_dataset_evaluators()
    def configure_optimizers(
        self,
    ) -> Tuple[List[torch.optim.Optimizer], List]:
        """Use detectron2's optimizer/scheduler builders, stepping per-iteration."""
        optim = d2_build_optimizer(self.cfg, self.model)
        lr_scheduler = d2_build_lr_scheduler(self.cfg, optim)
        return [optim], [{"scheduler": lr_scheduler, "interval": "step"}]
    def train_dataloader(self):
        return Detectron2GoRunner.build_detection_train_loader(self.cfg)
    def _reset_dataset_evaluators(self):
        """reset validation dataset evaluator to be run in EVAL_PERIOD steps"""
        assert (
            not self.trainer.distributed_backend
            or self.trainer.distributed_backend.lower()
            in [
                "ddp",
                "ddp_cpu",
            ]
        ), (
            "Only DDP and DDP_CPU distributed backend are supported"
        )
        def _get_inference_dir_name(
            base_dir, inference_type, dataset_name, model_tag: ModelTag
        ):
            # Name the output folder after the iteration the NEXT eval will
            # happen at (global_step + EVAL_PERIOD, minus 1 for the first run).
            next_eval_iter = self.trainer.global_step + self.cfg.TEST.EVAL_PERIOD
            if self.trainer.global_step == 0:
                next_eval_iter -= 1
            return os.path.join(
                base_dir,
                inference_type,
                model_tag,
                str(next_eval_iter),
                dataset_name,
            )
        for tag, dataset_evaluators in self.dataset_evaluators.items():
            dataset_evaluators.clear()
            assert self.cfg.OUTPUT_DIR, "Expect output_dir to be specified in config"
            for dataset_name in self.cfg.DATASETS.TEST:
                # setup evaluator for each dataset
                output_folder = _get_inference_dir_name(
                    self.cfg.OUTPUT_DIR, "inference", dataset_name, tag
                )
                evaluator = Detectron2GoRunner.get_evaluator(
                    self.cfg, dataset_name, output_folder=output_folder
                )
                evaluator.reset()
                dataset_evaluators.append(evaluator)
                # TODO: add visualization evaluator
    def _evaluation_dataloader(self):
        """One test loader per dataset in cfg.DATASETS.TEST, evaluators reset."""
        # TODO: Support subsample n images
        assert len(self.cfg.DATASETS.TEST)
        dataloaders = []
        for dataset_name in self.cfg.DATASETS.TEST:
            dataloaders.append(
                Detectron2GoRunner.build_detection_test_loader(self.cfg, dataset_name)
            )
        self._reset_dataset_evaluators()
        return dataloaders
    def test_dataloader(self):
        return self._evaluation_dataloader()
    def val_dataloader(self):
        return self._evaluation_dataloader()
    def forward(self, input):
        return self.model(input)
    def on_pretrain_routine_end(self) -> None:
        """Initialize EMA state from the model unless restored from checkpoint."""
        if self.cfg.MODEL_EMA.ENABLED:
            if self.ema_state and self.ema_state.has_inited():
                # ema_state could have been loaded from checkpoint
                return
            self.ema_state = EMAState.from_model(
                self.model,
                decay=self.cfg.MODEL_EMA.DECAY,
                device=self.cfg.MODEL_EMA.DEVICE or self.cfg.MODEL.DEVICE,
            )
    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
        # Fold the just-updated weights into the EMA after every train batch.
        if self.ema_state:
            self.ema_state.update(self.model)
    def on_test_epoch_start(self):
        self._on_evaluation_epoch_start()
    def on_validation_epoch_start(self):
        self._on_evaluation_epoch_start()
    def _on_evaluation_epoch_start(self):
        # Materialize EMA weights into the dedicated eval copy of the model.
        if self.ema_state:
            self.ema_state.apply_to(self.model_ema)
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Persist EMA weights alongside the Lightning checkpoint payload.
        if self.ema_state:
            checkpoint["model_ema"] = self.ema_state.state_dict()
    def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]) -> None:
        """
        Called before model state is restored. Explicitly handles old model
        states so we can resume training from D2Go checkpoints transparently.
        Args:
            checkpointed_state: The raw checkpoint state as returned by torch.load
                or equivalent.
        """
        # If this is a non-Lightning checkpoint, we need to convert it.
        if not _is_lightning_checkpoint(checkpointed_state) and not _is_d2go_checkpoint(
            checkpointed_state
        ):
            raise ValueError(
                f"Invalid checkpoint state with keys: {checkpointed_state.keys()}"
            )
        if not _is_lightning_checkpoint(checkpointed_state):
            _convert_to_lightning(checkpointed_state)
        if self.ema_state:
            if "model_ema" not in checkpointed_state:
                rank_zero_info(
                    "EMA is enabled but EMA state is not found in given checkpoint"
                )
            else:
                self.ema_state = EMAState()
                self.ema_state.load_state_dict(checkpointed_state["model_ema"])
                if not self.ema_state.device:
                    # EMA state device not given, move to module device
                    self.ema_state.to(self.device)
class GeneralizedRCNNTask(DefaultTask):
    """Lightning task using the GeneralizedRCNN runner's config defaults."""

    @classmethod
    def get_default_cfg(cls):
        """Delegate to the RCNN runner so task and runner defaults stay in sync."""
        default_cfg = GeneralizedRCNNRunner.get_default_cfg()
        return default_cfg
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import argparse
import logging
import os
import time
import detectron2.utils.comm as comm
import torch
from d2go.config import (
CfgNode as CN,
auto_scale_world_size,
reroute_config_path,
temp_defrost,
)
from d2go.distributed import get_local_rank, get_num_processes_per_machine
from d2go.runner import GeneralizedRCNNRunner, create_runner
from d2go.utils.launch_environment import get_launch_environment
from detectron2.utils.collect_env import collect_env_info
from detectron2.utils.logger import setup_logger
from detectron2.utils.serialize import PicklableWrapper
from fvcore.common.file_io import PathManager
from d2go.utils.helper import run_once
from mobile_cv.common.misc.py import FolderLock, MultiprocessingPdb, post_mortem_if_fail
logger = logging.getLogger(__name__)
def basic_argument_parser(
    distributed=True,
    requires_config_file=True,
    requires_output_dir=True,
):
    """Basic cli tool parser for Detectron2Go binaries.

    Args:
        distributed: add DDP-related flags (--num-processes, --num-machines,
            --machine-rank, --dist-url, --dist-backend).
        requires_config_file: whether --config-file must be provided; when
            False, the common config values (--datasets, --min_size,
            --max_size) must be supplied explicitly instead.
        requires_output_dir: whether --output-dir must be provided.

    Returns:
        argparse.ArgumentParser configured with the requested options.
    """
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--runner",
        type=str,
        default="d2go.runner.GeneralizedRCNNRunner",
        help="Full class name, i.e. (package.)module.class",
    )
    parser.add_argument(
        "--config-file",
        help="path to config file",
        default="",
        required=requires_config_file,
        metavar="FILE",
    )
    parser.add_argument(
        "--output-dir",
        help="When given, this will override the OUTPUT_DIR in the config-file",
        required=requires_output_dir,
        default=None,
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    if distributed:
        parser.add_argument(
            "--num-processes", type=int, default=1, help="number of gpus per machine"
        )
        parser.add_argument("--num-machines", type=int, default=1)
        parser.add_argument(
            "--machine-rank",
            type=int,
            default=0,
            help="the rank of this machine (unique per machine)",
        )
        parser.add_argument(
            "--dist-url", default="file:///tmp/d2go_dist_file_{}".format(time.time())
        )
        parser.add_argument("--dist-backend", type=str, default="NCCL")
    if not requires_config_file:
        # NOTE if not passing yaml file, user should explicitly set the
        # following args, and use `opts` for non-common usecase.
        parser.add_argument(
            "--datasets",
            type=str,
            nargs="+",
            required=True,
            help="cfg.DATASETS.TEST",
        )
        parser.add_argument(
            "--min_size",
            type=int,
            required=True,
            help="cfg.INPUT.MIN_SIZE_TEST",
        )
        parser.add_argument(
            "--max_size",
            type=int,
            required=True,
            help="cfg.INPUT.MAX_SIZE_TEST",
        )
    # BUGFIX: there used to be two `return parser` statements — one inside the
    # `if not requires_config_file` branch shadowing this one; a single exit
    # point keeps the control flow obvious and behavior unchanged.
    return parser
def create_cfg_from_cli_args(args, default_cfg):
    """
    Instead of loading from defaults.py, this binary only includes necessary
    configs building from scratch, and overrides them from args. There're two
    levels of config:
        _C: the config system used by this binary, which is a sub-set of training
            config, override by configurable_cfg. It can also be override by
            args.opts for convinience.
        configurable_cfg: common configs that user should explicitly specify
            in the args.
    """
    base = CN()
    base.INPUT = default_cfg.INPUT
    base.DATASETS = default_cfg.DATASETS
    base.DATALOADER = default_cfg.DATALOADER
    base.TEST = default_cfg.TEST
    # Optional sub-configs: copy them over only when the runner defines them.
    for optional_key in ("D2GO_DATA", "TENSORBOARD"):
        if hasattr(default_cfg, optional_key):
            setattr(base, optional_key, getattr(default_cfg, optional_key))
    # NOTE configs below might not be necessary, but must add to make code work
    base.MODEL = CN()
    base.MODEL.META_ARCHITECTURE = default_cfg.MODEL.META_ARCHITECTURE
    base.MODEL.MASK_ON = default_cfg.MODEL.MASK_ON
    base.MODEL.KEYPOINT_ON = default_cfg.MODEL.KEYPOINT_ON
    base.MODEL.LOAD_PROPOSALS = default_cfg.MODEL.LOAD_PROPOSALS
    assert base.MODEL.LOAD_PROPOSALS is False, "caffe2 model doesn't support"
    base.OUTPUT_DIR = args.output_dir
    overrides = [
        "DATASETS.TEST",
        args.datasets,
        "INPUT.MIN_SIZE_TEST",
        args.min_size,
        "INPUT.MAX_SIZE_TEST",
        args.max_size,
    ]
    cfg = base.clone()
    cfg.merge_from_list(overrides)
    cfg.merge_from_list(args.opts)
    return cfg
def prepare_for_launch(args):
    """
    Load config, figure out working directory, create runner.
        - when args.config_file is empty, returned cfg will be the default one
        - returned output_dir will always be non empty, args.output_dir has higher
          priority than cfg.OUTPUT_DIR.
    """
    print(args)
    runner = create_runner(args.runner)
    cfg = runner.get_default_cfg()
    if not args.config_file:
        # No yaml file: build a minimal config directly from the cli args.
        cfg = create_cfg_from_cli_args(args, default_cfg=cfg)
    else:
        with PathManager.open(reroute_config_path(args.config_file), "r") as f:
            print("Loaded config file {}:\n{}".format(args.config_file, f.read()))
        cfg.merge_from_file(args.config_file)
        cfg.merge_from_list(args.opts)
    cfg.freeze()
    assert args.output_dir or args.config_file
    output_dir = args.output_dir or cfg.OUTPUT_DIR
    return cfg, output_dir, runner
def setup_after_launch(cfg, output_dir, runner):
    """
    Set things up after entering DDP, including
        - creating working directory
        - setting up logger
        - logging environment
        - initializing runner
    """
    # Only the global main process creates the directory; everyone waits.
    create_dir_on_global_main_process(output_dir)
    comm.synchronize()
    setup_loggers(output_dir)
    cfg.freeze()
    if cfg.OUTPUT_DIR != output_dir:
        # output_dir (from cli) wins over the config's own OUTPUT_DIR.
        with temp_defrost(cfg):
            logger.warning(
                "Override cfg.OUTPUT_DIR ({}) to be the same as output_dir {}".format(
                    cfg.OUTPUT_DIR, output_dir
                )
            )
            cfg.OUTPUT_DIR = output_dir
    logger.info("Initializing runner ...")
    runner = initialize_runner(runner, cfg)
    log_info(cfg, runner)
    dump_cfg(cfg, os.path.join(output_dir, "config.yaml"))
    # Rescale iteration-based configs now that the true world size is known.
    auto_scale_world_size(cfg, new_world_size=comm.get_world_size())
@run_once()
def setup_loggers(output_dir, color=None):
    """Configure the detectron2/fvcore/d2go/mobile_cv loggers (once per process)."""
    if not color:
        # Colorize output only for local (interactive) launches.
        color = get_launch_environment() == "local"
    # (logger name, abbreviated name); None means keep setup_logger's default.
    logger_specs = [
        ("detectron2", "d2"),
        ("fvcore", None),
        ("d2go", "d2go"),
        ("mobile_cv", "mobile_cv"),
    ]
    configured_loggers = []
    for logger_name, abbrev in logger_specs:
        kwargs = {
            "distributed_rank": comm.get_rank(),
            "color": color,
            "name": logger_name,
        }
        if abbrev is not None:
            kwargs["abbrev_name"] = abbrev
        configured_loggers.append(setup_logger(output_dir, **kwargs))
    # NOTE: all above loggers have FileHandler pointing to the same file as the
    # detectron2 logger. Those files are opened upon creation, but it seems fine
    # in 'a' mode.
    # NOTE: the root logger might has been configured by other applications,
    # since this already sub-top level, just don't propagate to root.
    for configured in configured_loggers:
        configured.propagate = False
def log_info(cfg, runner):
    """Log process topology, environment info, the full config and the runner."""
    logger.info(
        "Using {} processes per machine. Rank of current process: {}".format(
            get_num_processes_per_machine(), comm.get_rank()
        )
    )
    logger.info("Environment info:\n" + collect_env_info())
    logger.info("Running with full config:\n{}".format(cfg))
    logger.info("Running with runner: {}".format(runner))
def dump_cfg(cfg, path):
    """Serialize `cfg` to `path` (main process only)."""
    if not comm.is_main_process():
        return
    with PathManager.open(path, "w") as f:
        f.write(cfg.dump())
    logger.info("Full config saved to {}".format(path))
def create_dir_on_local_main_process(dir):
    """Create `dir` from the local-rank-0 process only (no-op for empty dir)."""
    if get_local_rank() != 0 or not dir:
        return
    PathManager.mkdirs(dir)
def create_dir_on_global_main_process(dir):
    """Create `dir` from the global-rank-0 process only (no-op for empty dir)."""
    if comm.get_rank() != 0 or not dir:
        return
    PathManager.mkdirs(dir)
def initialize_runner(runner, cfg):
    """Initialize `runner` with `cfg`, defaulting to GeneralizedRCNNRunner."""
    if not runner:
        runner = GeneralizedRCNNRunner()
    runner._initialize(cfg)
    return runner
def caffe2_global_init(logging_print_net_summary=0, num_threads=None):
    """Initialize the caffe2 workspace with sensible thread/logging defaults."""
    if num_threads is None:
        # By default use a single thread when DDP runs multiple processes;
        # otherwise keep PyTorch's current value (GlobalInit would clean
        # PyTorch's num_threads and set it to 1).
        num_threads = (
            1 if get_num_processes_per_machine() > 1 else torch.get_num_threads()
        )
    if get_local_rank() != 0:
        logging_print_net_summary = 0  # only enable for local main process
    from caffe2.python import workspace
    workspace.GlobalInit(
        [
            "caffe2",
            "--caffe2_log_level=2",
            "--caffe2_logging_print_net_summary={}".format(logging_print_net_summary),
            "--caffe2_omp_num_threads={}".format(num_threads),
            "--caffe2_mkl_num_threads={}".format(num_threads),
        ]
    )
    logger.info("Using {} threads after GlobalInit".format(torch.get_num_threads()))
def post_mortem_if_fail_for_main(main_func):
    """Wrap `main_func` so failures drop into a (multiprocessing-aware) pdb."""
    def new_main_func(cfg, output_dir, *args, **kwargs):
        if comm.get_world_size() > 1:
            # Plain pdb can't be shared across processes; coordinate via a
            # folder lock in the output directory.
            pdb_ = MultiprocessingPdb(FolderLock(output_dir))
        else:
            pdb_ = None  # fallback to use normal pdb for single process
        return post_mortem_if_fail(pdb_)(main_func)(cfg, output_dir, *args, **kwargs)
    return PicklableWrapper(new_main_func)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import os
import detectron2.utils.comm as comm
import torch
from d2go.utils.visualization import VisualizerWrapper
from fvcore.common.file_io import PathManager
logger = logging.getLogger(__name__)
def get_rel_loss_checker(rel_thres=1.0):
    """Build a loss-validity predicate flagging large relative loss jumps.

    The returned callable takes (prev_loss, loss) dicts of loss terms and
    returns False when the summed loss grew by at least `rel_thres` relative
    to the previous sum; True otherwise (including when there is no history
    or the previous sum is non-positive).
    """
    def _loss_delta_exceeds_thresh(prev_loss, loss):
        if prev_loss is None:
            return True
        previous_total = sum(prev_loss.values())
        current_total = sum(loss.values())
        if previous_total <= 0:
            return True
        return (current_total - previous_total) / previous_total < rel_thres
    return _loss_delta_exceeds_thresh
class TrainImageWriter(object):
    """Write visualizations of training inputs with abnormal losses to tensorboard."""

    def __init__(self, cfg, tbx_writer, max_count=5):
        """max_count: max number of data written to tensorboard, additional call
        will be ignored
        """
        self.visualizer = VisualizerWrapper(cfg)
        self.writer = tbx_writer
        self.max_count = max_count
        self.counter = 0

    def __call__(self, all_data):
        # Respect the write budget; max_count <= 0 means unlimited.
        if self.max_count > 0 and self.counter >= self.max_count:
            return
        data = all_data["data"]
        step = all_data["step"]
        for idx, cur_data in enumerate(data):
            name = f"train_abnormal_losses/{step}/img_{idx}/{cur_data['file_name']}"
            vis_img = self.visualizer.visualize_train_input(cur_data)
            self.writer._writer.add_image(name, vis_img, step, dataformats="HWC")
        logger.warning(
            "Train images with bad losses written to tensorboard 'train_abnormal_losses'"
        )
        self.counter += 1
class FileWriter(object):
    """Dumps the full abnormal-loss payload to a per-rank .pth file in the
    output directory.
    """

    def __init__(self, output_dir, max_count=5):
        """ max_count: max number of data written to tensorboard, additional call
        will be ignored
        """
        self.output_dir = output_dir
        self.max_count = max_count
        self.counter = 0

    def __call__(self, all_data):
        # Skip once the write budget is exhausted (max_count <= 0 means no cap).
        if 0 < self.max_count <= self.counter:
            return

        step = all_data["step"]
        losses = all_data["losses"]
        # Per-rank filename so concurrent DDP workers do not clobber each other.
        file_name = f"train_abnormal_losses_{step}_{comm.get_rank()}.pth"
        out_file = os.path.join(self.output_dir, file_name)
        with PathManager.open(out_file, "wb") as fp:
            torch.save(all_data, fp)
        logger.warning(
            f"Iteration {step} has bad losses {losses}. "
            f"all information saved to {out_file}."
        )
        self.counter += 1
def get_writers(cfg, tbx_writer):
    """Default abnormal-loss writers: tensorboard images plus on-disk dumps."""
    image_writer = TrainImageWriter(cfg, tbx_writer)
    file_writer = FileWriter(cfg.OUTPUT_DIR)
    return [image_writer, file_writer]
class AbnormalLossChecker(object):
    """Tracks per-step loss dicts, flags steps whose losses look abnormal
    (per `valid_loss_checker`), and forwards details to a list of writers.
    """

    def __init__(self, start_iter, writers, valid_loss_checker=None):
        # Default to the relative-delta check when no custom checker is given.
        self.valid_loss_checker = valid_loss_checker or get_rel_loss_checker()
        self.writers = writers or []
        assert isinstance(self.writers, list)

        self.prev_index = start_iter
        self.prev_loss = None

    def check_step(self, losses, data=None, model=None):
        """Validate `losses` against the previous step; returns True when valid."""
        with torch.no_grad():
            valid = self.valid_loss_checker(self.prev_loss, losses)
        if not valid:
            self._write_invalid_info(losses, data, model)
        self.prev_index += 1
        self.prev_loss = losses
        return valid

    def _write_invalid_info(self, losses, data, model):
        # Unwrap DDP-style wrappers so the raw model is recorded.
        unwrapped = model.module if hasattr(model, "module") else model
        info = {
            "losses": losses,
            "data": data,
            "model": unwrapped,
            "prev_loss": self.prev_loss,
            "step": self.prev_index + 1,
        }
        for write in self.writers:
            write(info)
class AbnormalLossCheckerWrapper(torch.nn.Module):
    """nn.Module wrapper that feeds every forward's losses through an
    AbnormalLossChecker before returning them unchanged.
    """

    def __init__(self, model, checker):
        super().__init__()
        self.checker = checker
        self.model = model
        # Mirror the wrapped model's train/eval flag at wrap time.
        self.training = model.training

    def forward(self, x):
        loss_dict = self.model(x)
        self.checker.check_step(loss_dict, data=x, model=self.model)
        return loss_dict
#!/usr/bin/env python3
import itertools
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
class EMAState(object):
    """Stores Exponential Moving Average state for a model.
    Args:
        decay: EMA decay factor, should be in [0, 1]. A decay of 0 corresponds to
        always using the latest value (no EMA) and a decay of 1 corresponds to
        not updating weights after initialization. Default to 0.999.
        device: If not None, move model EMA state to device.
    """

    def __init__(self, decay: float = 0.999, device: Optional[str] = None):
        if decay < 0 or decay > 1.0:
            raise ValueError(f"Decay should be in [0, 1], {decay} was given.")
        self.decay: float = decay
        self.state: Dict[str, Any] = {}
        self.device: Optional[str] = device

    @classmethod
    def from_model(
        cls,
        model: nn.Module,
        decay: float = 0.999,
        device: Optional[str] = None,
    ) -> "EMAState":
        """Build an EMAState snapshotting `model`, moved to `device` if given."""
        ema = cls(decay, device)
        ema.load_from(model)
        return ema

    def load_from(self, model: nn.Module) -> None:
        """Snapshot the model's parameters and buffers, replacing prior state."""
        self.state.clear()
        for name, tensor in self._get_model_state_iterator(model):
            snapshot = tensor.detach().clone()
            if self.device:
                snapshot = snapshot.to(self.device)
            self.state[name] = snapshot

    def has_inited(self) -> bool:
        """True once state has been loaded from a model or a state dict."""
        return bool(self.state)

    def apply_to(self, model: nn.Module) -> None:
        """Copy stored EMA values into the model's tensors in place."""
        with torch.no_grad():
            for name, tensor in self._get_model_state_iterator(model):
                assert (
                    name in self.state
                ), f"Name {name} does not exist, available names are {self.state.keys()}"
                tensor.copy_(self.state[name])

    def state_dict(self) -> Dict[str, Any]:
        # NOTE: returns the live internal dict (no copy), matching usage
        # elsewhere that mutates state via load_state_dict.
        return self.state

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Replace internal state with `state_dict`, moving to device if set."""
        self.state.clear()
        for name, tensor in state_dict.items():
            self.state[name] = tensor.to(self.device) if self.device else tensor

    def to(self, device: torch.device) -> None:
        """Move every stored tensor to `device`."""
        for name in list(self.state):
            self.state[name] = self.state[name].to(device)

    def _get_model_state_iterator(self, model: nn.Module):
        # Parameters first, then buffers — the same ordering load_from and
        # apply_to rely on.
        return itertools.chain(model.named_parameters(), model.named_buffers())

    def update(self, model: nn.Module) -> None:
        """Blend the model's current values into the EMA state in place:
        ema = ema * decay + current * (1 - decay).
        """
        with torch.no_grad():
            for name, cur in self._get_model_state_iterator(model):
                ema_val = self.state[name]
                if self.device:
                    cur = cur.to(self.device)
                ema_val.copy_(ema_val * self.decay + cur * (1.0 - self.decay))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment