Commit 3144257c authored by mashun1's avatar mashun1
Browse files

catvton

parents
_BASE_: "../COCO-Detection/retinanet_R_50_FPN_3x.yaml"
MODEL:
WEIGHTS: "detectron2://COCO-Detection/retinanet_R_50_FPN_3x/190397829/model_final_5bd44e.pkl"
DATASETS:
TEST: ("coco_2017_val_100",)
TEST:
EXPECTED_RESULTS: [["bbox", "AP", 44.45, 0.02]]
_BASE_: "../COCO-Detection/retinanet_R_50_FPN_1x.yaml"
MODEL:
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
DATASETS:
TRAIN: ("coco_2017_val_100",)
TEST: ("coco_2017_val_100",)
SOLVER:
BASE_LR: 0.005
STEPS: (30,)
MAX_ITER: 40
IMS_PER_BATCH: 4
DATALOADER:
NUM_WORKERS: 2
_BASE_: "../COCO-Detection/rpn_R_50_FPN_1x.yaml"
MODEL:
WEIGHTS: "detectron2://COCO-Detection/rpn_R_50_FPN_1x/137258492/model_final_02ce48.pkl"
DATASETS:
TEST: ("coco_2017_val_100",)
TEST:
EXPECTED_RESULTS: [["box_proposals", "AR@1000", 58.16, 0.02]]
_BASE_: "../COCO-Detection/rpn_R_50_FPN_1x.yaml"
MODEL:
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
DATASETS:
TRAIN: ("coco_2017_val_100",)
TEST: ("coco_2017_val_100",)
SOLVER:
STEPS: (30,)
MAX_ITER: 40
BASE_LR: 0.005
IMS_PER_BATCH: 4
DATALOADER:
NUM_WORKERS: 2
_BASE_: "../Base-RCNN-FPN.yaml"
MODEL:
META_ARCHITECTURE: "SemanticSegmentor"
WEIGHTS: "detectron2://semantic_R_50_FPN_1x/111802073/model_final_c18079783c55a94968edc28b7101c5f0.pkl"
RESNETS:
DEPTH: 50
DATASETS:
TEST: ("coco_2017_val_100_panoptic_stuffonly",)
TEST:
EXPECTED_RESULTS: [["sem_seg", "mIoU", 39.53, 0.02], ["sem_seg", "mACC", 51.50, 0.02]]
_BASE_: "../Base-RCNN-FPN.yaml"
MODEL:
META_ARCHITECTURE: "SemanticSegmentor"
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
RESNETS:
DEPTH: 50
DATASETS:
TRAIN: ("coco_2017_val_100_panoptic_stuffonly",)
TEST: ("coco_2017_val_100_panoptic_stuffonly",)
INPUT:
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
SOLVER:
BASE_LR: 0.005
STEPS: (30,)
MAX_ITER: 40
IMS_PER_BATCH: 4
DATALOADER:
NUM_WORKERS: 2
_BASE_: "../Base-RCNN-FPN.yaml"
MODEL:
META_ARCHITECTURE: "SemanticSegmentor"
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
RESNETS:
DEPTH: 50
DATASETS:
TRAIN: ("coco_2017_val_panoptic_stuffonly",)
TEST: ("coco_2017_val_panoptic_stuffonly",)
SOLVER:
BASE_LR: 0.01
WARMUP_FACTOR: 0.001
WARMUP_ITERS: 300
STEPS: (5500,)
MAX_ITER: 7000
TEST:
EXPECTED_RESULTS: [["sem_seg", "mIoU", 76.51, 1.0], ["sem_seg", "mACC", 83.25, 1.0]]
INPUT:
# no scale augmentation
MIN_SIZE_TRAIN: (800, )
# Copyright (c) Facebook, Inc. and its affiliates.
import os
from typing import Optional
import pkg_resources
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import CfgNode, LazyConfig, get_cfg, instantiate
from detectron2.modeling import build_model
class _ModelZooUrls:
"""
Mapping from names to officially released Detectron2 pre-trained models.
"""
S3_PREFIX = "https://dl.fbaipublicfiles.com/detectron2/"
# format: {config_path.yaml} -> model_id/model_final_{commit}.pkl
CONFIG_PATH_TO_URL_SUFFIX = {
# COCO Detection with Faster R-CNN
"COCO-Detection/faster_rcnn_R_50_C4_1x": "137257644/model_final_721ade.pkl",
"COCO-Detection/faster_rcnn_R_50_DC5_1x": "137847829/model_final_51d356.pkl",
"COCO-Detection/faster_rcnn_R_50_FPN_1x": "137257794/model_final_b275ba.pkl",
"COCO-Detection/faster_rcnn_R_50_C4_3x": "137849393/model_final_f97cb7.pkl",
"COCO-Detection/faster_rcnn_R_50_DC5_3x": "137849425/model_final_68d202.pkl",
"COCO-Detection/faster_rcnn_R_50_FPN_3x": "137849458/model_final_280758.pkl",
"COCO-Detection/faster_rcnn_R_101_C4_3x": "138204752/model_final_298dad.pkl",
"COCO-Detection/faster_rcnn_R_101_DC5_3x": "138204841/model_final_3e0943.pkl",
"COCO-Detection/faster_rcnn_R_101_FPN_3x": "137851257/model_final_f6e8b1.pkl",
"COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x": "139173657/model_final_68b088.pkl",
# COCO Detection with RetinaNet
"COCO-Detection/retinanet_R_50_FPN_1x": "190397773/model_final_bfca0b.pkl",
"COCO-Detection/retinanet_R_50_FPN_3x": "190397829/model_final_5bd44e.pkl",
"COCO-Detection/retinanet_R_101_FPN_3x": "190397697/model_final_971ab9.pkl",
# COCO Detection with RPN and Fast R-CNN
"COCO-Detection/rpn_R_50_C4_1x": "137258005/model_final_450694.pkl",
"COCO-Detection/rpn_R_50_FPN_1x": "137258492/model_final_02ce48.pkl",
"COCO-Detection/fast_rcnn_R_50_FPN_1x": "137635226/model_final_e5f7ce.pkl",
# COCO Instance Segmentation Baselines with Mask R-CNN
"COCO-InstanceSegmentation/mask_rcnn_R_50_C4_1x": "137259246/model_final_9243eb.pkl",
"COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_1x": "137260150/model_final_4f86c3.pkl",
"COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x": "137260431/model_final_a54504.pkl",
"COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x": "137849525/model_final_4ce675.pkl",
"COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_3x": "137849551/model_final_84107b.pkl",
"COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x": "137849600/model_final_f10217.pkl",
"COCO-InstanceSegmentation/mask_rcnn_R_101_C4_3x": "138363239/model_final_a2914c.pkl",
"COCO-InstanceSegmentation/mask_rcnn_R_101_DC5_3x": "138363294/model_final_0464b7.pkl",
"COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x": "138205316/model_final_a3ec72.pkl",
"COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x": "139653917/model_final_2d9806.pkl", # noqa
# New baselines using Large-Scale Jitter and Longer Training Schedule
"new_baselines/mask_rcnn_R_50_FPN_100ep_LSJ": "42047764/model_final_bb69de.pkl",
"new_baselines/mask_rcnn_R_50_FPN_200ep_LSJ": "42047638/model_final_89a8d3.pkl",
"new_baselines/mask_rcnn_R_50_FPN_400ep_LSJ": "42019571/model_final_14d201.pkl",
"new_baselines/mask_rcnn_R_101_FPN_100ep_LSJ": "42025812/model_final_4f7b58.pkl",
"new_baselines/mask_rcnn_R_101_FPN_200ep_LSJ": "42131867/model_final_0bb7ae.pkl",
"new_baselines/mask_rcnn_R_101_FPN_400ep_LSJ": "42073830/model_final_f96b26.pkl",
"new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_100ep_LSJ": "42047771/model_final_b7fbab.pkl", # noqa
"new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_200ep_LSJ": "42132721/model_final_5d87c1.pkl", # noqa
"new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_400ep_LSJ": "42025447/model_final_f1362d.pkl", # noqa
"new_baselines/mask_rcnn_regnety_4gf_dds_FPN_100ep_LSJ": "42047784/model_final_6ba57e.pkl", # noqa
"new_baselines/mask_rcnn_regnety_4gf_dds_FPN_200ep_LSJ": "42047642/model_final_27b9c1.pkl", # noqa
"new_baselines/mask_rcnn_regnety_4gf_dds_FPN_400ep_LSJ": "42045954/model_final_ef3a80.pkl", # noqa
# COCO Person Keypoint Detection Baselines with Keypoint R-CNN
"COCO-Keypoints/keypoint_rcnn_R_50_FPN_1x": "137261548/model_final_04e291.pkl",
"COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x": "137849621/model_final_a6e10b.pkl",
"COCO-Keypoints/keypoint_rcnn_R_101_FPN_3x": "138363331/model_final_997cc7.pkl",
"COCO-Keypoints/keypoint_rcnn_X_101_32x8d_FPN_3x": "139686956/model_final_5ad38f.pkl",
# COCO Panoptic Segmentation Baselines with Panoptic FPN
"COCO-PanopticSegmentation/panoptic_fpn_R_50_1x": "139514544/model_final_dbfeb4.pkl",
"COCO-PanopticSegmentation/panoptic_fpn_R_50_3x": "139514569/model_final_c10459.pkl",
"COCO-PanopticSegmentation/panoptic_fpn_R_101_3x": "139514519/model_final_cafdb1.pkl",
# LVIS Instance Segmentation Baselines with Mask R-CNN
"LVISv0.5-InstanceSegmentation/mask_rcnn_R_50_FPN_1x": "144219072/model_final_571f7c.pkl", # noqa
"LVISv0.5-InstanceSegmentation/mask_rcnn_R_101_FPN_1x": "144219035/model_final_824ab5.pkl", # noqa
"LVISv0.5-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_1x": "144219108/model_final_5e3439.pkl", # noqa
# Cityscapes & Pascal VOC Baselines
"Cityscapes/mask_rcnn_R_50_FPN": "142423278/model_final_af9cf5.pkl",
"PascalVOC-Detection/faster_rcnn_R_50_C4": "142202221/model_final_b1acc2.pkl",
# Other Settings
"Misc/mask_rcnn_R_50_FPN_1x_dconv_c3-c5": "138602867/model_final_65c703.pkl",
"Misc/mask_rcnn_R_50_FPN_3x_dconv_c3-c5": "144998336/model_final_821d0b.pkl",
"Misc/cascade_mask_rcnn_R_50_FPN_1x": "138602847/model_final_e9d89b.pkl",
"Misc/cascade_mask_rcnn_R_50_FPN_3x": "144998488/model_final_480dd8.pkl",
"Misc/mask_rcnn_R_50_FPN_3x_syncbn": "169527823/model_final_3b3c51.pkl",
"Misc/mask_rcnn_R_50_FPN_3x_gn": "138602888/model_final_dc5d9e.pkl",
"Misc/scratch_mask_rcnn_R_50_FPN_3x_gn": "138602908/model_final_01ca85.pkl",
"Misc/scratch_mask_rcnn_R_50_FPN_9x_gn": "183808979/model_final_da7b4c.pkl",
"Misc/scratch_mask_rcnn_R_50_FPN_9x_syncbn": "184226666/model_final_5ce33e.pkl",
"Misc/panoptic_fpn_R_101_dconv_cascade_gn_3x": "139797668/model_final_be35db.pkl",
"Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv": "18131413/model_0039999_e76410.pkl", # noqa
# D1 Comparisons
"Detectron1-Comparisons/faster_rcnn_R_50_FPN_noaug_1x": "137781054/model_final_7ab50c.pkl", # noqa
"Detectron1-Comparisons/mask_rcnn_R_50_FPN_noaug_1x": "137781281/model_final_62ca52.pkl", # noqa
"Detectron1-Comparisons/keypoint_rcnn_R_50_FPN_1x": "137781195/model_final_cce136.pkl",
}
@staticmethod
def query(config_path: str) -> Optional[str]:
"""
Args:
config_path: relative config filename
"""
name = config_path.replace(".yaml", "").replace(".py", "")
if name in _ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX:
suffix = _ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX[name]
return _ModelZooUrls.S3_PREFIX + name + "/" + suffix
return None
def get_checkpoint_url(config_path):
"""
Returns the URL to the model trained using the given config
Args:
config_path (str): config file name relative to detectron2's "configs/"
directory, e.g., "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"
Returns:
str: a URL to the model
"""
url = _ModelZooUrls.query(config_path)
if url is None:
raise RuntimeError("Pretrained model for {} is not available!".format(config_path))
return url
def get_config_file(config_path):
"""
Returns path to a builtin config file.
Args:
config_path (str): config file name relative to detectron2's "configs/"
directory, e.g., "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"
Returns:
str: the real path to the config file.
"""
cfg_file = pkg_resources.resource_filename(
"detectron2.model_zoo", os.path.join("configs", config_path)
)
if not os.path.exists(cfg_file):
raise RuntimeError("{} not available in Model Zoo!".format(config_path))
return cfg_file
def get_config(config_path, trained: bool = False):
"""
Returns a config object for a model in model zoo.
Args:
config_path (str): config file name relative to detectron2's "configs/"
directory, e.g., "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"
trained (bool): If True, will set ``MODEL.WEIGHTS`` to trained model zoo weights.
If False, the checkpoint specified in the config file's ``MODEL.WEIGHTS`` is used
instead; this will typically (though not always) initialize a subset of weights using
an ImageNet pre-trained model, while randomly initializing the other weights.
Returns:
CfgNode or omegaconf.DictConfig: a config object
"""
cfg_file = get_config_file(config_path)
if cfg_file.endswith(".yaml"):
cfg = get_cfg()
cfg.merge_from_file(cfg_file)
if trained:
cfg.MODEL.WEIGHTS = get_checkpoint_url(config_path)
return cfg
elif cfg_file.endswith(".py"):
cfg = LazyConfig.load(cfg_file)
if trained:
url = get_checkpoint_url(config_path)
if "train" in cfg and "init_checkpoint" in cfg.train:
cfg.train.init_checkpoint = url
else:
raise NotImplementedError
return cfg
def get(config_path, trained: bool = False, device: Optional[str] = None):
"""
Get a model specified by relative path under Detectron2's official ``configs/`` directory.
Args:
config_path (str): config file name relative to detectron2's "configs/"
directory, e.g., "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"
trained (bool): see :func:`get_config`.
device (str or None): overwrite the device in config, if given.
Returns:
nn.Module: a detectron2 model. Will be in training mode.
Example:
::
from detectron2 import model_zoo
model = model_zoo.get("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml", trained=True)
"""
cfg = get_config(config_path, trained)
if device is None and not torch.cuda.is_available():
device = "cpu"
if device is not None and isinstance(cfg, CfgNode):
cfg.MODEL.DEVICE = device
if isinstance(cfg, CfgNode):
model = build_model(cfg)
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
else:
model = instantiate(cfg.model)
if device is not None:
model = model.to(device)
if "train" in cfg and "init_checkpoint" in cfg.train:
DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
return model
# Copyright (c) Facebook, Inc. and its affiliates.
from detectron2.layers import ShapeSpec
from .anchor_generator import build_anchor_generator, ANCHOR_GENERATOR_REGISTRY
from .backbone import (
BACKBONE_REGISTRY,
FPN,
Backbone,
ResNet,
ResNetBlockBase,
build_backbone,
build_resnet_backbone,
make_stage,
ViT,
SimpleFeaturePyramid,
get_vit_lr_decay_rate,
MViT,
SwinTransformer,
)
from .meta_arch import (
META_ARCH_REGISTRY,
SEM_SEG_HEADS_REGISTRY,
GeneralizedRCNN,
PanopticFPN,
ProposalNetwork,
RetinaNet,
SemanticSegmentor,
build_model,
build_sem_seg_head,
FCOS,
)
from .postprocessing import detector_postprocess
from .proposal_generator import (
PROPOSAL_GENERATOR_REGISTRY,
build_proposal_generator,
RPN_HEAD_REGISTRY,
build_rpn_head,
)
from .roi_heads import (
ROI_BOX_HEAD_REGISTRY,
ROI_HEADS_REGISTRY,
ROI_KEYPOINT_HEAD_REGISTRY,
ROI_MASK_HEAD_REGISTRY,
ROIHeads,
StandardROIHeads,
BaseMaskRCNNHead,
BaseKeypointRCNNHead,
FastRCNNOutputLayers,
build_box_head,
build_keypoint_head,
build_mask_head,
build_roi_heads,
)
from .test_time_augmentation import DatasetMapperTTA, GeneralizedRCNNWithTTA
from .mmdet_wrapper import MMDetBackbone, MMDetDetector
_EXCLUDE = {"ShapeSpec"}
__all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")]
from detectron2.utils.env import fixup_module_metadata
fixup_module_metadata(__name__, globals(), __all__)
del fixup_module_metadata
# Copyright (c) Facebook, Inc. and its affiliates.
import collections
import math
from typing import List
import torch
from torch import nn
from detectron2.config import configurable
from detectron2.layers import ShapeSpec, move_device_like
from detectron2.structures import Boxes, RotatedBoxes
from detectron2.utils.registry import Registry
ANCHOR_GENERATOR_REGISTRY = Registry("ANCHOR_GENERATOR")
ANCHOR_GENERATOR_REGISTRY.__doc__ = """
Registry for modules that creates object detection anchors for feature maps.
The registered object will be called with `obj(cfg, input_shape)`.
"""
class BufferList(nn.Module):
"""
Similar to nn.ParameterList, but for buffers
"""
def __init__(self, buffers):
super().__init__()
for i, buffer in enumerate(buffers):
# Use non-persistent buffer so the values are not saved in checkpoint
self.register_buffer(str(i), buffer, persistent=False)
def __len__(self):
return len(self._buffers)
def __iter__(self):
return iter(self._buffers.values())
def _create_grid_offsets(
size: List[int], stride: int, offset: float, target_device_tensor: torch.Tensor
):
grid_height, grid_width = size
shifts_x = move_device_like(
torch.arange(offset * stride, grid_width * stride, step=stride, dtype=torch.float32),
target_device_tensor,
)
shifts_y = move_device_like(
torch.arange(offset * stride, grid_height * stride, step=stride, dtype=torch.float32),
target_device_tensor,
)
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
return shift_x, shift_y
def _broadcast_params(params, num_features, name):
"""
If one size (or aspect ratio) is specified and there are multiple feature
maps, we "broadcast" anchors of that single size (or aspect ratio)
over all feature maps.
If params is list[float], or list[list[float]] with len(params) == 1, repeat
it num_features time.
Returns:
list[list[float]]: param for each feature
"""
assert isinstance(
params, collections.abc.Sequence
), f"{name} in anchor generator has to be a list! Got {params}."
assert len(params), f"{name} in anchor generator cannot be empty!"
if not isinstance(params[0], collections.abc.Sequence): # params is list[float]
return [params] * num_features
if len(params) == 1:
return list(params) * num_features
assert len(params) == num_features, (
f"Got {name} of length {len(params)} in anchor generator, "
f"but the number of input features is {num_features}!"
)
return params
@ANCHOR_GENERATOR_REGISTRY.register()
class DefaultAnchorGenerator(nn.Module):
"""
Compute anchors in the standard ways described in
"Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks".
"""
box_dim: torch.jit.Final[int] = 4
"""
the dimension of each anchor box.
"""
@configurable
def __init__(self, *, sizes, aspect_ratios, strides, offset=0.5):
"""
This interface is experimental.
Args:
sizes (list[list[float]] or list[float]):
If ``sizes`` is list[list[float]], ``sizes[i]`` is the list of anchor sizes
(i.e. sqrt of anchor area) to use for the i-th feature map.
If ``sizes`` is list[float], ``sizes`` is used for all feature maps.
Anchor sizes are given in absolute lengths in units of
the input image; they do not dynamically scale if the input image size changes.
aspect_ratios (list[list[float]] or list[float]): list of aspect ratios
(i.e. height / width) to use for anchors. Same "broadcast" rule for `sizes` applies.
strides (list[int]): stride of each input feature.
offset (float): Relative offset between the center of the first anchor and the top-left
corner of the image. Value has to be in [0, 1).
Recommend to use 0.5, which means half stride.
"""
super().__init__()
self.strides = strides
self.num_features = len(self.strides)
sizes = _broadcast_params(sizes, self.num_features, "sizes")
aspect_ratios = _broadcast_params(aspect_ratios, self.num_features, "aspect_ratios")
self.cell_anchors = self._calculate_anchors(sizes, aspect_ratios)
self.offset = offset
assert 0.0 <= self.offset < 1.0, self.offset
@classmethod
def from_config(cls, cfg, input_shape: List[ShapeSpec]):
return {
"sizes": cfg.MODEL.ANCHOR_GENERATOR.SIZES,
"aspect_ratios": cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS,
"strides": [x.stride for x in input_shape],
"offset": cfg.MODEL.ANCHOR_GENERATOR.OFFSET,
}
def _calculate_anchors(self, sizes, aspect_ratios):
cell_anchors = [
self.generate_cell_anchors(s, a).float() for s, a in zip(sizes, aspect_ratios)
]
return BufferList(cell_anchors)
@property
@torch.jit.unused
def num_cell_anchors(self):
"""
Alias of `num_anchors`.
"""
return self.num_anchors
@property
@torch.jit.unused
def num_anchors(self):
"""
Returns:
list[int]: Each int is the number of anchors at every pixel
location, on that feature map.
For example, if at every pixel we use anchors of 3 aspect
ratios and 5 sizes, the number of anchors is 15.
(See also ANCHOR_GENERATOR.SIZES and ANCHOR_GENERATOR.ASPECT_RATIOS in config)
In standard RPN models, `num_anchors` on every feature map is the same.
"""
return [len(cell_anchors) for cell_anchors in self.cell_anchors]
def _grid_anchors(self, grid_sizes: List[List[int]]):
"""
Returns:
list[Tensor]: #featuremap tensors, each is (#locations x #cell_anchors) x 4
"""
anchors = []
# buffers() not supported by torchscript. use named_buffers() instead
buffers: List[torch.Tensor] = [x[1] for x in self.cell_anchors.named_buffers()]
for size, stride, base_anchors in zip(grid_sizes, self.strides, buffers):
shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))
return anchors
def generate_cell_anchors(self, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2)):
"""
Generate a tensor storing canonical anchor boxes, which are all anchor
boxes of different sizes and aspect_ratios centered at (0, 0).
We can later build the set of anchors for a full feature map by
shifting and tiling these tensors (see `meth:_grid_anchors`).
Args:
sizes (tuple[float]):
aspect_ratios (tuple[float]]):
Returns:
Tensor of shape (len(sizes) * len(aspect_ratios), 4) storing anchor boxes
in XYXY format.
"""
# This is different from the anchor generator defined in the original Faster R-CNN
# code or Detectron. They yield the same AP, however the old version defines cell
# anchors in a less natural way with a shift relative to the feature grid and
# quantization that results in slightly different sizes for different aspect ratios.
# See also https://github.com/facebookresearch/Detectron/issues/227
anchors = []
for size in sizes:
area = size**2.0
for aspect_ratio in aspect_ratios:
# s * s = w * h
# a = h / w
# ... some algebra ...
# w = sqrt(s * s / a)
# h = a * w
w = math.sqrt(area / aspect_ratio)
h = aspect_ratio * w
x0, y0, x1, y1 = -w / 2.0, -h / 2.0, w / 2.0, h / 2.0
anchors.append([x0, y0, x1, y1])
return torch.tensor(anchors)
def forward(self, features: List[torch.Tensor]):
"""
Args:
features (list[Tensor]): list of backbone feature maps on which to generate anchors.
Returns:
list[Boxes]: a list of Boxes containing all the anchors for each feature map
(i.e. the cell anchors repeated over all locations in the feature map).
The number of anchors of each feature map is Hi x Wi x num_cell_anchors,
where Hi, Wi are resolution of the feature map divided by anchor stride.
"""
grid_sizes = [feature_map.shape[-2:] for feature_map in features]
anchors_over_all_feature_maps = self._grid_anchors(grid_sizes) # pyre-ignore
return [Boxes(x) for x in anchors_over_all_feature_maps]
@ANCHOR_GENERATOR_REGISTRY.register()
class RotatedAnchorGenerator(nn.Module):
"""
Compute rotated anchors used by Rotated RPN (RRPN), described in
"Arbitrary-Oriented Scene Text Detection via Rotation Proposals".
"""
box_dim: int = 5
"""
the dimension of each anchor box.
"""
@configurable
def __init__(self, *, sizes, aspect_ratios, strides, angles, offset=0.5):
"""
This interface is experimental.
Args:
sizes (list[list[float]] or list[float]):
If sizes is list[list[float]], sizes[i] is the list of anchor sizes
(i.e. sqrt of anchor area) to use for the i-th feature map.
If sizes is list[float], the sizes are used for all feature maps.
Anchor sizes are given in absolute lengths in units of
the input image; they do not dynamically scale if the input image size changes.
aspect_ratios (list[list[float]] or list[float]): list of aspect ratios
(i.e. height / width) to use for anchors. Same "broadcast" rule for `sizes` applies.
strides (list[int]): stride of each input feature.
angles (list[list[float]] or list[float]): list of angles (in degrees CCW)
to use for anchors. Same "broadcast" rule for `sizes` applies.
offset (float): Relative offset between the center of the first anchor and the top-left
corner of the image. Value has to be in [0, 1).
Recommend to use 0.5, which means half stride.
"""
super().__init__()
self.strides = strides
self.num_features = len(self.strides)
sizes = _broadcast_params(sizes, self.num_features, "sizes")
aspect_ratios = _broadcast_params(aspect_ratios, self.num_features, "aspect_ratios")
angles = _broadcast_params(angles, self.num_features, "angles")
self.cell_anchors = self._calculate_anchors(sizes, aspect_ratios, angles)
self.offset = offset
assert 0.0 <= self.offset < 1.0, self.offset
@classmethod
def from_config(cls, cfg, input_shape: List[ShapeSpec]):
return {
"sizes": cfg.MODEL.ANCHOR_GENERATOR.SIZES,
"aspect_ratios": cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS,
"strides": [x.stride for x in input_shape],
"offset": cfg.MODEL.ANCHOR_GENERATOR.OFFSET,
"angles": cfg.MODEL.ANCHOR_GENERATOR.ANGLES,
}
def _calculate_anchors(self, sizes, aspect_ratios, angles):
cell_anchors = [
self.generate_cell_anchors(size, aspect_ratio, angle).float()
for size, aspect_ratio, angle in zip(sizes, aspect_ratios, angles)
]
return BufferList(cell_anchors)
@property
def num_cell_anchors(self):
"""
Alias of `num_anchors`.
"""
return self.num_anchors
@property
def num_anchors(self):
"""
Returns:
list[int]: Each int is the number of anchors at every pixel
location, on that feature map.
For example, if at every pixel we use anchors of 3 aspect
ratios, 2 sizes and 5 angles, the number of anchors is 30.
(See also ANCHOR_GENERATOR.SIZES, ANCHOR_GENERATOR.ASPECT_RATIOS
and ANCHOR_GENERATOR.ANGLES in config)
In standard RRPN models, `num_anchors` on every feature map is the same.
"""
return [len(cell_anchors) for cell_anchors in self.cell_anchors]
def _grid_anchors(self, grid_sizes: List[List[int]]):
anchors = []
for size, stride, base_anchors in zip(
grid_sizes,
self.strides,
self.cell_anchors._buffers.values(),
):
shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors)
zeros = torch.zeros_like(shift_x)
shifts = torch.stack((shift_x, shift_y, zeros, zeros, zeros), dim=1)
anchors.append((shifts.view(-1, 1, 5) + base_anchors.view(1, -1, 5)).reshape(-1, 5))
return anchors
def generate_cell_anchors(
self,
sizes=(32, 64, 128, 256, 512),
aspect_ratios=(0.5, 1, 2),
angles=(-90, -60, -30, 0, 30, 60, 90),
):
"""
Generate a tensor storing canonical anchor boxes, which are all anchor
boxes of different sizes, aspect_ratios, angles centered at (0, 0).
We can later build the set of anchors for a full feature map by
shifting and tiling these tensors (see `meth:_grid_anchors`).
Args:
sizes (tuple[float]):
aspect_ratios (tuple[float]]):
angles (tuple[float]]):
Returns:
Tensor of shape (len(sizes) * len(aspect_ratios) * len(angles), 5)
storing anchor boxes in (x_ctr, y_ctr, w, h, angle) format.
"""
anchors = []
for size in sizes:
area = size**2.0
for aspect_ratio in aspect_ratios:
# s * s = w * h
# a = h / w
# ... some algebra ...
# w = sqrt(s * s / a)
# h = a * w
w = math.sqrt(area / aspect_ratio)
h = aspect_ratio * w
anchors.extend([0, 0, w, h, a] for a in angles)
return torch.tensor(anchors)
def forward(self, features):
"""
Args:
features (list[Tensor]): list of backbone feature maps on which to generate anchors.
Returns:
list[RotatedBoxes]: a list of Boxes containing all the anchors for each feature map
(i.e. the cell anchors repeated over all locations in the feature map).
The number of anchors of each feature map is Hi x Wi x num_cell_anchors,
where Hi, Wi are resolution of the feature map divided by anchor stride.
"""
grid_sizes = [feature_map.shape[-2:] for feature_map in features]
anchors_over_all_feature_maps = self._grid_anchors(grid_sizes)
return [RotatedBoxes(x) for x in anchors_over_all_feature_maps]
def build_anchor_generator(cfg, input_shape):
"""
Built an anchor generator from `cfg.MODEL.ANCHOR_GENERATOR.NAME`.
"""
anchor_generator = cfg.MODEL.ANCHOR_GENERATOR.NAME
return ANCHOR_GENERATOR_REGISTRY.get(anchor_generator)(cfg, input_shape)
# Copyright (c) Facebook, Inc. and its affiliates.
from .build import build_backbone, BACKBONE_REGISTRY # noqa F401 isort:skip
from .backbone import Backbone
from .fpn import FPN
from .regnet import RegNet
from .resnet import (
BasicStem,
ResNet,
ResNetBlockBase,
build_resnet_backbone,
make_stage,
BottleneckBlock,
)
from .vit import ViT, SimpleFeaturePyramid, get_vit_lr_decay_rate
from .mvit import MViT
from .swin import SwinTransformer
__all__ = [k for k in globals().keys() if not k.startswith("_")]
# TODO can expose more resnet blocks after careful consideration
# Copyright (c) Facebook, Inc. and its affiliates.
from abc import ABCMeta, abstractmethod
from typing import Dict
import torch.nn as nn
from detectron2.layers import ShapeSpec
__all__ = ["Backbone"]
class Backbone(nn.Module, metaclass=ABCMeta):
"""
Abstract base class for network backbones.
"""
def __init__(self):
"""
The `__init__` method of any subclass can specify its own set of arguments.
"""
super().__init__()
@abstractmethod
def forward(self):
"""
Subclasses must override this method, but adhere to the same return type.
Returns:
dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor
"""
pass
@property
def size_divisibility(self) -> int:
"""
Some backbones require the input height and width to be divisible by a
specific integer. This is typically true for encoder / decoder type networks
with lateral connection (e.g., FPN) for which feature maps need to match
dimension in the "bottom up" and "top down" paths. Set to 0 if no specific
input size divisibility is required.
"""
return 0
@property
def padding_constraints(self) -> Dict[str, int]:
"""
This property is a generalization of size_divisibility. Some backbones and training
recipes require specific padding constraints, such as enforcing divisibility by a specific
integer (e.g., FPN) or padding to a square (e.g., ViTDet with large-scale jitter
in :paper:vitdet). `padding_constraints` contains these optional items like:
{
"size_divisibility": int,
"square_size": int,
# Future options are possible
}
`size_divisibility` will read from here if presented and `square_size` indicates the
square padding size if `square_size` > 0.
TODO: use type of Dict[str, int] to avoid torchscipt issues. The type of padding_constraints
could be generalized as TypedDict (Python 3.8+) to support more types in the future.
"""
return {}
def output_shape(self):
"""
Returns:
dict[str->ShapeSpec]
"""
# this is a backward-compatible default
return {
name: ShapeSpec(
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
)
for name in self._out_features
}
# Copyright (c) Facebook, Inc. and its affiliates.
from detectron2.layers import ShapeSpec
from detectron2.utils.registry import Registry
from .backbone import Backbone
BACKBONE_REGISTRY = Registry("BACKBONE")
BACKBONE_REGISTRY.__doc__ = """
Registry for backbones, which extract feature maps from images
The registered object must be a callable that accepts two arguments:
1. A :class:`detectron2.config.CfgNode`
2. A :class:`detectron2.layers.ShapeSpec`, which contains the input shape specification.
Registered object must return instance of :class:`Backbone`.
"""
def build_backbone(cfg, input_shape=None):
"""
Build a backbone from `cfg.MODEL.BACKBONE.NAME`.
Returns:
an instance of :class:`Backbone`
"""
if input_shape is None:
input_shape = ShapeSpec(channels=len(cfg.MODEL.PIXEL_MEAN))
backbone_name = cfg.MODEL.BACKBONE.NAME
backbone = BACKBONE_REGISTRY.get(backbone_name)(cfg, input_shape)
assert isinstance(backbone, Backbone)
return backbone
# Copyright (c) Facebook, Inc. and its affiliates.
import math
import fvcore.nn.weight_init as weight_init
import torch
import torch.nn.functional as F
from torch import nn
from detectron2.layers import Conv2d, ShapeSpec, get_norm
from .backbone import Backbone
from .build import BACKBONE_REGISTRY
from .resnet import build_resnet_backbone
__all__ = ["build_resnet_fpn_backbone", "build_retinanet_resnet_fpn_backbone", "FPN"]
class FPN(Backbone):
"""
This module implements :paper:`FPN`.
It creates pyramid features built on top of some input feature maps.
"""
_fuse_type: torch.jit.Final[str]
def __init__(
self,
bottom_up,
in_features,
out_channels,
norm="",
top_block=None,
fuse_type="sum",
square_pad=0,
):
"""
Args:
bottom_up (Backbone): module representing the bottom up subnetwork.
Must be a subclass of :class:`Backbone`. The multi-scale feature
maps generated by the bottom up network, and listed in `in_features`,
are used to generate FPN levels.
in_features (list[str]): names of the input feature maps coming
from the backbone to which FPN is attached. For example, if the
backbone produces ["res2", "res3", "res4"], any *contiguous* sublist
of these may be used; order must be from high to low resolution.
out_channels (int): number of channels in the output feature maps.
norm (str): the normalization to use.
top_block (nn.Module or None): if provided, an extra operation will
be performed on the output of the last (smallest resolution)
FPN output, and the result will extend the result list. The top_block
further downsamples the feature map. It must have an attribute
"num_levels", meaning the number of extra FPN levels added by
this block, and "in_feature", which is a string representing
its input feature (e.g., p5).
fuse_type (str): types for fusing the top down features and the lateral
ones. It can be "sum" (default), which sums up element-wise; or "avg",
which takes the element-wise mean of the two.
square_pad (int): If > 0, require input images to be padded to specific square size.
"""
super(FPN, self).__init__()
assert isinstance(bottom_up, Backbone)
assert in_features, in_features
# Feature map strides and channels from the bottom up network (e.g. ResNet)
input_shapes = bottom_up.output_shape()
strides = [input_shapes[f].stride for f in in_features]
in_channels_per_feature = [input_shapes[f].channels for f in in_features]
_assert_strides_are_log2_contiguous(strides)
lateral_convs = []
output_convs = []
use_bias = norm == ""
for idx, in_channels in enumerate(in_channels_per_feature):
lateral_norm = get_norm(norm, out_channels)
output_norm = get_norm(norm, out_channels)
lateral_conv = Conv2d(
in_channels, out_channels, kernel_size=1, bias=use_bias, norm=lateral_norm
)
output_conv = Conv2d(
out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=use_bias,
norm=output_norm,
)
weight_init.c2_xavier_fill(lateral_conv)
weight_init.c2_xavier_fill(output_conv)
stage = int(math.log2(strides[idx]))
self.add_module("fpn_lateral{}".format(stage), lateral_conv)
self.add_module("fpn_output{}".format(stage), output_conv)
lateral_convs.append(lateral_conv)
output_convs.append(output_conv)
# Place convs into top-down order (from low to high resolution)
# to make the top-down computation in forward clearer.
self.lateral_convs = lateral_convs[::-1]
self.output_convs = output_convs[::-1]
self.top_block = top_block
self.in_features = tuple(in_features)
self.bottom_up = bottom_up
# Return feature names are "p<stage>", like ["p2", "p3", ..., "p6"]
self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
# top block output feature maps.
if self.top_block is not None:
for s in range(stage, stage + self.top_block.num_levels):
self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)
self._out_features = list(self._out_feature_strides.keys())
self._out_feature_channels = {k: out_channels for k in self._out_features}
self._size_divisibility = strides[-1]
self._square_pad = square_pad
assert fuse_type in {"avg", "sum"}
self._fuse_type = fuse_type
@property
def size_divisibility(self):
return self._size_divisibility
@property
def padding_constraints(self):
return {"square_size": self._square_pad}
def forward(self, x):
"""
Args:
input (dict[str->Tensor]): mapping feature map name (e.g., "res5") to
feature map tensor for each feature level in high to low resolution order.
Returns:
dict[str->Tensor]:
mapping from feature map name to FPN feature map tensor
in high to low resolution order. Returned feature names follow the FPN
paper convention: "p<stage>", where stage has stride = 2 ** stage e.g.,
["p2", "p3", ..., "p6"].
"""
bottom_up_features = self.bottom_up(x)
results = []
prev_features = self.lateral_convs[0](bottom_up_features[self.in_features[-1]])
results.append(self.output_convs[0](prev_features))
# Reverse feature maps into top-down order (from low to high resolution)
for idx, (lateral_conv, output_conv) in enumerate(
zip(self.lateral_convs, self.output_convs)
):
# Slicing of ModuleList is not supported https://github.com/pytorch/pytorch/issues/47336
# Therefore we loop over all modules but skip the first one
if idx > 0:
features = self.in_features[-idx - 1]
features = bottom_up_features[features]
top_down_features = F.interpolate(prev_features, scale_factor=2.0, mode="nearest")
lateral_features = lateral_conv(features)
prev_features = lateral_features + top_down_features
if self._fuse_type == "avg":
prev_features /= 2
results.insert(0, output_conv(prev_features))
if self.top_block is not None:
if self.top_block.in_feature in bottom_up_features:
top_block_in_feature = bottom_up_features[self.top_block.in_feature]
else:
top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
results.extend(self.top_block(top_block_in_feature))
assert len(self._out_features) == len(results)
return {f: res for f, res in zip(self._out_features, results)}
def output_shape(self):
return {
name: ShapeSpec(
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
)
for name in self._out_features
}
def _assert_strides_are_log2_contiguous(strides):
"""
Assert that each stride is 2x times its preceding stride, i.e. "contiguous in log2".
"""
for i, stride in enumerate(strides[1:], 1):
assert stride == 2 * strides[i - 1], "Strides {} {} are not log2 contiguous".format(
stride, strides[i - 1]
)
class LastLevelMaxPool(nn.Module):
"""
This module is used in the original FPN to generate a downsampled
P6 feature from P5.
"""
def __init__(self):
super().__init__()
self.num_levels = 1
self.in_feature = "p5"
def forward(self, x):
return [F.max_pool2d(x, kernel_size=1, stride=2, padding=0)]
class LastLevelP6P7(nn.Module):
"""
This module is used in RetinaNet to generate extra layers, P6 and P7 from
C5 feature.
"""
def __init__(self, in_channels, out_channels, in_feature="res5"):
super().__init__()
self.num_levels = 2
self.in_feature = in_feature
self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
for module in [self.p6, self.p7]:
weight_init.c2_xavier_fill(module)
def forward(self, c5):
p6 = self.p6(c5)
p7 = self.p7(F.relu(p6))
return [p6, p7]
@BACKBONE_REGISTRY.register()
def build_resnet_fpn_backbone(cfg, input_shape: ShapeSpec):
"""
Args:
cfg: a detectron2 CfgNode
Returns:
backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
"""
bottom_up = build_resnet_backbone(cfg, input_shape)
in_features = cfg.MODEL.FPN.IN_FEATURES
out_channels = cfg.MODEL.FPN.OUT_CHANNELS
backbone = FPN(
bottom_up=bottom_up,
in_features=in_features,
out_channels=out_channels,
norm=cfg.MODEL.FPN.NORM,
top_block=LastLevelMaxPool(),
fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
)
return backbone
@BACKBONE_REGISTRY.register()
def build_retinanet_resnet_fpn_backbone(cfg, input_shape: ShapeSpec):
"""
Args:
cfg: a detectron2 CfgNode
Returns:
backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
"""
bottom_up = build_resnet_backbone(cfg, input_shape)
in_features = cfg.MODEL.FPN.IN_FEATURES
out_channels = cfg.MODEL.FPN.OUT_CHANNELS
in_channels_p6p7 = bottom_up.output_shape()["res5"].channels
backbone = FPN(
bottom_up=bottom_up,
in_features=in_features,
out_channels=out_channels,
norm=cfg.MODEL.FPN.NORM,
top_block=LastLevelP6P7(in_channels_p6p7, out_channels),
fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
)
return backbone
import logging
import numpy as np
import torch
import torch.nn as nn
from .backbone import Backbone
from .utils import (
PatchEmbed,
add_decomposed_rel_pos,
get_abs_pos,
window_partition,
window_unpartition,
)
logger = logging.getLogger(__name__)
__all__ = ["MViT"]
def attention_pool(x, pool, norm=None):
# (B, H, W, C) -> (B, C, H, W)
x = x.permute(0, 3, 1, 2)
x = pool(x)
# (B, C, H1, W1) -> (B, H1, W1, C)
x = x.permute(0, 2, 3, 1)
if norm:
x = norm(x)
return x
class MultiScaleAttention(nn.Module):
"""Multiscale Multi-head Attention block."""
def __init__(
self,
dim,
dim_out,
num_heads,
qkv_bias=True,
norm_layer=nn.LayerNorm,
pool_kernel=(3, 3),
stride_q=1,
stride_kv=1,
residual_pooling=True,
window_size=0,
use_rel_pos=False,
rel_pos_zero_init=True,
input_size=None,
):
"""
Args:
dim (int): Number of input channels.
dim_out (int): Number of output channels.
num_heads (int): Number of attention heads.
qkv_bias (bool: If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
pool_kernel (tuple): kernel size for qkv pooling layers.
stride_q (int): stride size for q pooling layer.
stride_kv (int): stride size for kv pooling layer.
residual_pooling (bool): If true, enable residual pooling.
use_rel_pos (bool): If True, add relative postional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (int or None): Input resolution.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim_out // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias)
self.proj = nn.Linear(dim_out, dim_out)
# qkv pooling
pool_padding = [k // 2 for k in pool_kernel]
dim_conv = dim_out // num_heads
self.pool_q = nn.Conv2d(
dim_conv,
dim_conv,
pool_kernel,
stride=stride_q,
padding=pool_padding,
groups=dim_conv,
bias=False,
)
self.norm_q = norm_layer(dim_conv)
self.pool_k = nn.Conv2d(
dim_conv,
dim_conv,
pool_kernel,
stride=stride_kv,
padding=pool_padding,
groups=dim_conv,
bias=False,
)
self.norm_k = norm_layer(dim_conv)
self.pool_v = nn.Conv2d(
dim_conv,
dim_conv,
pool_kernel,
stride=stride_kv,
padding=pool_padding,
groups=dim_conv,
bias=False,
)
self.norm_v = norm_layer(dim_conv)
self.window_size = window_size
if window_size:
self.q_win_size = window_size // stride_q
self.kv_win_size = window_size // stride_kv
self.residual_pooling = residual_pooling
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
# initialize relative positional embeddings
assert input_size[0] == input_size[1]
size = input_size[0]
rel_dim = 2 * max(size // stride_q, size // stride_kv) - 1
self.rel_pos_h = nn.Parameter(torch.zeros(rel_dim, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(rel_dim, head_dim))
if not rel_pos_zero_init:
nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
nn.init.trunc_normal_(self.rel_pos_w, std=0.02)
def forward(self, x):
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H, W, C)
qkv = self.qkv(x).reshape(B, H, W, 3, self.num_heads, -1).permute(3, 0, 4, 1, 2, 5)
# q, k, v with shape (B * nHead, H, W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H, W, -1).unbind(0)
q = attention_pool(q, self.pool_q, self.norm_q)
k = attention_pool(k, self.pool_k, self.norm_k)
v = attention_pool(v, self.pool_v, self.norm_v)
ori_q = q
if self.window_size:
q, q_hw_pad = window_partition(q, self.q_win_size)
k, kv_hw_pad = window_partition(k, self.kv_win_size)
v, _ = window_partition(v, self.kv_win_size)
q_hw = (self.q_win_size, self.q_win_size)
kv_hw = (self.kv_win_size, self.kv_win_size)
else:
q_hw = q.shape[1:3]
kv_hw = k.shape[1:3]
q = q.view(q.shape[0], np.prod(q_hw), -1)
k = k.view(k.shape[0], np.prod(kv_hw), -1)
v = v.view(v.shape[0], np.prod(kv_hw), -1)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, q_hw, kv_hw)
attn = attn.softmax(dim=-1)
x = attn @ v
x = x.view(x.shape[0], q_hw[0], q_hw[1], -1)
if self.window_size:
x = window_unpartition(x, self.q_win_size, q_hw_pad, ori_q.shape[1:3])
if self.residual_pooling:
x += ori_q
H, W = x.shape[1], x.shape[2]
x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
return x
class MultiScaleBlock(nn.Module):
"""Multiscale Transformer blocks"""
def __init__(
self,
dim,
dim_out,
num_heads,
mlp_ratio=4.0,
qkv_bias=True,
drop_path=0.0,
norm_layer=nn.LayerNorm,
act_layer=nn.GELU,
qkv_pool_kernel=(3, 3),
stride_q=1,
stride_kv=1,
residual_pooling=True,
window_size=0,
use_rel_pos=False,
rel_pos_zero_init=True,
input_size=None,
):
"""
Args:
dim (int): Number of input channels.
dim_out (int): Number of output channels.
num_heads (int): Number of attention heads in the MViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
drop_path (float): Stochastic depth rate.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
qkv_pool_kernel (tuple): kernel size for qkv pooling layers.
stride_q (int): stride size for q pooling layer.
stride_kv (int): stride size for kv pooling layer.
residual_pooling (bool): If true, enable residual pooling.
window_size (int): Window size for window attention blocks. If it equals 0, then not
use window attention.
use_rel_pos (bool): If True, add relative postional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (int or None): Input resolution.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = MultiScaleAttention(
dim,
dim_out,
num_heads=num_heads,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
pool_kernel=qkv_pool_kernel,
stride_q=stride_q,
stride_kv=stride_kv,
residual_pooling=residual_pooling,
window_size=window_size,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size,
)
from timm.models.layers import DropPath, Mlp
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim_out)
self.mlp = Mlp(
in_features=dim_out,
hidden_features=int(dim_out * mlp_ratio),
out_features=dim_out,
act_layer=act_layer,
)
if dim != dim_out:
self.proj = nn.Linear(dim, dim_out)
if stride_q > 1:
kernel_skip = stride_q + 1
padding_skip = int(kernel_skip // 2)
self.pool_skip = nn.MaxPool2d(kernel_skip, stride_q, padding_skip, ceil_mode=False)
def forward(self, x):
x_norm = self.norm1(x)
x_block = self.attn(x_norm)
if hasattr(self, "proj"):
x = self.proj(x_norm)
if hasattr(self, "pool_skip"):
x = attention_pool(x, self.pool_skip)
x = x + self.drop_path(x_block)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class MViT(Backbone):
"""
This module implements Multiscale Vision Transformer (MViT) backbone in :paper:'mvitv2'.
"""
def __init__(
self,
img_size=224,
patch_kernel=(7, 7),
patch_stride=(4, 4),
patch_padding=(3, 3),
in_chans=3,
embed_dim=96,
depth=16,
num_heads=1,
last_block_indexes=(0, 2, 11, 15),
qkv_pool_kernel=(3, 3),
adaptive_kv_stride=4,
adaptive_window_size=56,
residual_pooling=True,
mlp_ratio=4.0,
qkv_bias=True,
drop_path_rate=0.0,
norm_layer=nn.LayerNorm,
act_layer=nn.GELU,
use_abs_pos=False,
use_rel_pos=True,
rel_pos_zero_init=True,
use_act_checkpoint=False,
pretrain_img_size=224,
pretrain_use_cls_token=True,
out_features=("scale2", "scale3", "scale4", "scale5"),
):
"""
Args:
img_size (int): Input image size.
patch_kernel (tuple): kernel size for patch embedding.
patch_stride (tuple): stride size for patch embedding.
patch_padding (tuple): padding size for patch embedding.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of MViT.
num_heads (int): Number of base attention heads in each MViT block.
last_block_indexes (tuple): Block indexes for last blocks in each stage.
qkv_pool_kernel (tuple): kernel size for qkv pooling layers.
adaptive_kv_stride (int): adaptive stride size for kv pooling.
adaptive_window_size (int): adaptive window size for window attention blocks.
residual_pooling (bool): If true, enable residual pooling.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
drop_path_rate (float): Stochastic depth rate.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative postional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
use_act_checkpoint (bool): If True, use activation checkpointing.
pretrain_img_size (int): input image size for pretraining models.
pretrain_use_cls_token (bool): If True, pretrainig models use class token.
out_features (tuple): name of the feature maps from each stage.
"""
super().__init__()
self.pretrain_use_cls_token = pretrain_use_cls_token
self.patch_embed = PatchEmbed(
kernel_size=patch_kernel,
stride=patch_stride,
padding=patch_padding,
in_chans=in_chans,
embed_dim=embed_dim,
)
if use_abs_pos:
# Initialize absoluate positional embedding with pretrain image size.
num_patches = (pretrain_img_size // patch_stride[0]) * (
pretrain_img_size // patch_stride[1]
)
num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
else:
self.pos_embed = None
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
dim_out = embed_dim
stride_kv = adaptive_kv_stride
window_size = adaptive_window_size
input_size = (img_size // patch_stride[0], img_size // patch_stride[1])
stage = 2
stride = patch_stride[0]
self._out_feature_strides = {}
self._out_feature_channels = {}
self.blocks = nn.ModuleList()
for i in range(depth):
# Multiply stride_kv by 2 if it's the last block of stage2 and stage3.
if i == last_block_indexes[1] or i == last_block_indexes[2]:
stride_kv_ = stride_kv * 2
else:
stride_kv_ = stride_kv
# hybrid window attention: global attention in last three stages.
window_size_ = 0 if i in last_block_indexes[1:] else window_size
block = MultiScaleBlock(
dim=embed_dim,
dim_out=dim_out,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=dpr[i],
norm_layer=norm_layer,
qkv_pool_kernel=qkv_pool_kernel,
stride_q=2 if i - 1 in last_block_indexes else 1,
stride_kv=stride_kv_,
residual_pooling=residual_pooling,
window_size=window_size_,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size,
)
if use_act_checkpoint:
# TODO: use torch.utils.checkpoint
from fairscale.nn.checkpoint import checkpoint_wrapper
block = checkpoint_wrapper(block)
self.blocks.append(block)
embed_dim = dim_out
if i in last_block_indexes:
name = f"scale{stage}"
if name in out_features:
self._out_feature_channels[name] = dim_out
self._out_feature_strides[name] = stride
self.add_module(f"{name}_norm", norm_layer(dim_out))
dim_out *= 2
num_heads *= 2
stride_kv = max(stride_kv // 2, 1)
stride *= 2
stage += 1
if i - 1 in last_block_indexes:
window_size = window_size // 2
input_size = [s // 2 for s in input_size]
self._out_features = out_features
self._last_block_indexes = last_block_indexes
if self.pos_embed is not None:
nn.init.trunc_normal_(self.pos_embed, std=0.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + get_abs_pos(self.pos_embed, self.pretrain_use_cls_token, x.shape[1:3])
outputs = {}
stage = 2
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in self._last_block_indexes:
name = f"scale{stage}"
if name in self._out_features:
x_out = getattr(self, f"{name}_norm")(x)
outputs[name] = x_out.permute(0, 3, 1, 2)
stage += 1
return outputs
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Implementation of RegNet models from :paper:`dds` and :paper:`scaling`.
This code is adapted from https://github.com/facebookresearch/pycls with minimal modifications.
Some code duplication exists between RegNet and ResNets (e.g., ResStem) in order to simplify
model loading.
"""
import numpy as np
from torch import nn
from detectron2.layers import CNNBlockBase, ShapeSpec, get_norm
from .backbone import Backbone
__all__ = [
"AnyNet",
"RegNet",
"ResStem",
"SimpleStem",
"VanillaBlock",
"ResBasicBlock",
"ResBottleneckBlock",
]
def conv2d(w_in, w_out, k, *, stride=1, groups=1, bias=False):
"""Helper for building a conv2d layer."""
assert k % 2 == 1, "Only odd size kernels supported to avoid padding issues."
s, p, g, b = stride, (k - 1) // 2, groups, bias
return nn.Conv2d(w_in, w_out, k, stride=s, padding=p, groups=g, bias=b)
def gap2d():
"""Helper for building a global average pooling layer."""
return nn.AdaptiveAvgPool2d((1, 1))
def pool2d(k, *, stride=1):
"""Helper for building a pool2d layer."""
assert k % 2 == 1, "Only odd size kernels supported to avoid padding issues."
return nn.MaxPool2d(k, stride=stride, padding=(k - 1) // 2)
def init_weights(m):
"""Performs ResNet-style weight initialization."""
if isinstance(m, nn.Conv2d):
# Note that there is no bias due to BN
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(mean=0.0, std=np.sqrt(2.0 / fan_out))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1.0)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(mean=0.0, std=0.01)
m.bias.data.zero_()
class ResStem(CNNBlockBase):
"""ResNet stem for ImageNet: 7x7, BN, AF, MaxPool."""
def __init__(self, w_in, w_out, norm, activation_class):
super().__init__(w_in, w_out, 4)
self.conv = conv2d(w_in, w_out, 7, stride=2)
self.bn = get_norm(norm, w_out)
self.af = activation_class()
self.pool = pool2d(3, stride=2)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
class SimpleStem(CNNBlockBase):
"""Simple stem for ImageNet: 3x3, BN, AF."""
def __init__(self, w_in, w_out, norm, activation_class):
super().__init__(w_in, w_out, 2)
self.conv = conv2d(w_in, w_out, 3, stride=2)
self.bn = get_norm(norm, w_out)
self.af = activation_class()
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
class SE(nn.Module):
"""Squeeze-and-Excitation (SE) block: AvgPool, FC, Act, FC, Sigmoid."""
def __init__(self, w_in, w_se, activation_class):
super().__init__()
self.avg_pool = gap2d()
self.f_ex = nn.Sequential(
conv2d(w_in, w_se, 1, bias=True),
activation_class(),
conv2d(w_se, w_in, 1, bias=True),
nn.Sigmoid(),
)
def forward(self, x):
return x * self.f_ex(self.avg_pool(x))
class VanillaBlock(CNNBlockBase):
"""Vanilla block: [3x3 conv, BN, Relu] x2."""
def __init__(self, w_in, w_out, stride, norm, activation_class, _params):
super().__init__(w_in, w_out, stride)
self.a = conv2d(w_in, w_out, 3, stride=stride)
self.a_bn = get_norm(norm, w_out)
self.a_af = activation_class()
self.b = conv2d(w_out, w_out, 3)
self.b_bn = get_norm(norm, w_out)
self.b_af = activation_class()
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
class BasicTransform(nn.Module):
"""Basic transformation: [3x3 conv, BN, Relu] x2."""
def __init__(self, w_in, w_out, stride, norm, activation_class, _params):
super().__init__()
self.a = conv2d(w_in, w_out, 3, stride=stride)
self.a_bn = get_norm(norm, w_out)
self.a_af = activation_class()
self.b = conv2d(w_out, w_out, 3)
self.b_bn = get_norm(norm, w_out)
self.b_bn.final_bn = True
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
class ResBasicBlock(CNNBlockBase):
"""Residual basic block: x + f(x), f = basic transform."""
def __init__(self, w_in, w_out, stride, norm, activation_class, params):
super().__init__(w_in, w_out, stride)
self.proj, self.bn = None, None
if (w_in != w_out) or (stride != 1):
self.proj = conv2d(w_in, w_out, 1, stride=stride)
self.bn = get_norm(norm, w_out)
self.f = BasicTransform(w_in, w_out, stride, norm, activation_class, params)
self.af = activation_class()
def forward(self, x):
x_p = self.bn(self.proj(x)) if self.proj else x
return self.af(x_p + self.f(x))
class BottleneckTransform(nn.Module):
"""Bottleneck transformation: 1x1, 3x3 [+SE], 1x1."""
def __init__(self, w_in, w_out, stride, norm, activation_class, params):
super().__init__()
w_b = int(round(w_out * params["bot_mul"]))
w_se = int(round(w_in * params["se_r"]))
groups = w_b // params["group_w"]
self.a = conv2d(w_in, w_b, 1)
self.a_bn = get_norm(norm, w_b)
self.a_af = activation_class()
self.b = conv2d(w_b, w_b, 3, stride=stride, groups=groups)
self.b_bn = get_norm(norm, w_b)
self.b_af = activation_class()
self.se = SE(w_b, w_se, activation_class) if w_se else None
self.c = conv2d(w_b, w_out, 1)
self.c_bn = get_norm(norm, w_out)
self.c_bn.final_bn = True
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
class ResBottleneckBlock(CNNBlockBase):
"""Residual bottleneck block: x + f(x), f = bottleneck transform."""
def __init__(self, w_in, w_out, stride, norm, activation_class, params):
super().__init__(w_in, w_out, stride)
self.proj, self.bn = None, None
if (w_in != w_out) or (stride != 1):
self.proj = conv2d(w_in, w_out, 1, stride=stride)
self.bn = get_norm(norm, w_out)
self.f = BottleneckTransform(w_in, w_out, stride, norm, activation_class, params)
self.af = activation_class()
def forward(self, x):
x_p = self.bn(self.proj(x)) if self.proj else x
return self.af(x_p + self.f(x))
class AnyStage(nn.Module):
"""AnyNet stage (sequence of blocks w/ the same output shape)."""
def __init__(self, w_in, w_out, stride, d, block_class, norm, activation_class, params):
super().__init__()
for i in range(d):
block = block_class(w_in, w_out, stride, norm, activation_class, params)
self.add_module("b{}".format(i + 1), block)
stride, w_in = 1, w_out
def forward(self, x):
for block in self.children():
x = block(x)
return x
class AnyNet(Backbone):
"""AnyNet model. See :paper:`dds`."""
def __init__(
self,
*,
stem_class,
stem_width,
block_class,
depths,
widths,
group_widths,
strides,
bottleneck_ratios,
se_ratio,
activation_class,
freeze_at=0,
norm="BN",
out_features=None,
):
"""
Args:
stem_class (callable): A callable taking 4 arguments (channels in, channels out,
normalization, callable returning an activation function) that returns another
callable implementing the stem module.
stem_width (int): The number of output channels that the stem produces.
block_class (callable): A callable taking 6 arguments (channels in, channels out,
stride, normalization, callable returning an activation function, a dict of
block-specific parameters) that returns another callable implementing the repeated
block module.
depths (list[int]): Number of blocks in each stage.
widths (list[int]): For each stage, the number of output channels of each block.
group_widths (list[int]): For each stage, the number of channels per group in group
convolution, if the block uses group convolution.
strides (list[int]): The stride that each network stage applies to its input.
bottleneck_ratios (list[float]): For each stage, the ratio of the number of bottleneck
channels to the number of block input channels (or, equivalently, output channels),
if the block uses a bottleneck.
se_ratio (float): The ratio of the number of channels used inside the squeeze-excitation
(SE) module to it number of input channels, if SE the block uses SE.
activation_class (callable): A callable taking no arguments that returns another
callable implementing an activation function.
freeze_at (int): The number of stages at the beginning to freeze.
see :meth:`freeze` for detailed explanation.
norm (str or callable): normalization for all conv layers.
See :func:`layers.get_norm` for supported format.
out_features (list[str]): name of the layers whose outputs should
be returned in forward. RegNet's use "stem" and "s1", "s2", etc for the stages after
the stem. If None, will return the output of the last layer.
"""
super().__init__()
self.stem = stem_class(3, stem_width, norm, activation_class)
current_stride = self.stem.stride
self._out_feature_strides = {"stem": current_stride}
self._out_feature_channels = {"stem": self.stem.out_channels}
self.stages_and_names = []
prev_w = stem_width
for i, (d, w, s, b, g) in enumerate(
zip(depths, widths, strides, bottleneck_ratios, group_widths)
):
params = {"bot_mul": b, "group_w": g, "se_r": se_ratio}
stage = AnyStage(prev_w, w, s, d, block_class, norm, activation_class, params)
name = "s{}".format(i + 1)
self.add_module(name, stage)
self.stages_and_names.append((stage, name))
self._out_feature_strides[name] = current_stride = int(
current_stride * np.prod([k.stride for k in stage.children()])
)
self._out_feature_channels[name] = list(stage.children())[-1].out_channels
prev_w = w
self.apply(init_weights)
if out_features is None:
out_features = [name]
self._out_features = out_features
assert len(self._out_features)
children = [x[0] for x in self.named_children()]
for out_feature in self._out_features:
assert out_feature in children, "Available children: {} does not include {}".format(
", ".join(children), out_feature
)
self.freeze(freeze_at)
def forward(self, x):
"""
Args:
x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
Returns:
dict[str->Tensor]: names and the corresponding features
"""
assert x.dim() == 4, f"Model takes an input of shape (N, C, H, W). Got {x.shape} instead!"
outputs = {}
x = self.stem(x)
if "stem" in self._out_features:
outputs["stem"] = x
for stage, name in self.stages_and_names:
x = stage(x)
if name in self._out_features:
outputs[name] = x
return outputs
def output_shape(self):
return {
name: ShapeSpec(
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
)
for name in self._out_features
}
def freeze(self, freeze_at=0):
"""
Freeze the first several stages of the model. Commonly used in fine-tuning.
Layers that produce the same feature map spatial size are defined as one
"stage" by :paper:`FPN`.
Args:
freeze_at (int): number of stages to freeze.
`1` means freezing the stem. `2` means freezing the stem and
one residual stage, etc.
Returns:
nn.Module: this model itself
"""
if freeze_at >= 1:
self.stem.freeze()
for idx, (stage, _) in enumerate(self.stages_and_names, start=2):
if freeze_at >= idx:
for block in stage.children():
block.freeze()
return self
def adjust_block_compatibility(ws, bs, gs):
"""Adjusts the compatibility of widths, bottlenecks, and groups."""
assert len(ws) == len(bs) == len(gs)
assert all(w > 0 and b > 0 and g > 0 for w, b, g in zip(ws, bs, gs))
vs = [int(max(1, w * b)) for w, b in zip(ws, bs)]
gs = [int(min(g, v)) for g, v in zip(gs, vs)]
ms = [np.lcm(g, b) if b > 1 else g for g, b in zip(gs, bs)]
vs = [max(m, int(round(v / m) * m)) for v, m in zip(vs, ms)]
ws = [int(v / b) for v, b in zip(vs, bs)]
assert all(w * b % g == 0 for w, b, g in zip(ws, bs, gs))
return ws, bs, gs
def generate_regnet_parameters(w_a, w_0, w_m, d, q=8):
"""Generates per stage widths and depths from RegNet parameters."""
assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
# Generate continuous per-block ws
ws_cont = np.arange(d) * w_a + w_0
# Generate quantized per-block ws
ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
ws_all = w_0 * np.power(w_m, ks)
ws_all = np.round(np.divide(ws_all, q)).astype(int) * q
# Generate per stage ws and ds (assumes ws_all are sorted)
ws, ds = np.unique(ws_all, return_counts=True)
# Compute number of actual stages and total possible stages
num_stages, total_stages = len(ws), ks.max() + 1
# Convert numpy arrays to lists and return
ws, ds, ws_all, ws_cont = (x.tolist() for x in (ws, ds, ws_all, ws_cont))
return ws, ds, num_stages, total_stages, ws_all, ws_cont
class RegNet(AnyNet):
"""RegNet model. See :paper:`dds`."""
def __init__(
self,
*,
stem_class,
stem_width,
block_class,
depth,
w_a,
w_0,
w_m,
group_width,
stride=2,
bottleneck_ratio=1.0,
se_ratio=0.0,
activation_class=None,
freeze_at=0,
norm="BN",
out_features=None,
):
"""
Build a RegNet from the parameterization described in :paper:`dds` Section 3.3.
Args:
See :class:`AnyNet` for arguments that are not listed here.
depth (int): Total number of blocks in the RegNet.
w_a (float): Factor by which block width would increase prior to quantizing block widths
by stage. See :paper:`dds` Section 3.3.
w_0 (int): Initial block width. See :paper:`dds` Section 3.3.
w_m (float): Parameter controlling block width quantization.
See :paper:`dds` Section 3.3.
group_width (int): Number of channels per group in group convolution, if the block uses
group convolution.
bottleneck_ratio (float): The ratio of the number of bottleneck channels to the number
of block input channels (or, equivalently, output channels), if the block uses a
bottleneck.
stride (int): The stride that each network stage applies to its input.
"""
ws, ds = generate_regnet_parameters(w_a, w_0, w_m, depth)[0:2]
ss = [stride for _ in ws]
bs = [bottleneck_ratio for _ in ws]
gs = [group_width for _ in ws]
ws, bs, gs = adjust_block_compatibility(ws, bs, gs)
def default_activation_class():
return nn.ReLU(inplace=True)
super().__init__(
stem_class=stem_class,
stem_width=stem_width,
block_class=block_class,
depths=ds,
widths=ws,
strides=ss,
group_widths=gs,
bottleneck_ratios=bs,
se_ratio=se_ratio,
activation_class=(
default_activation_class if activation_class is None else activation_class
),
freeze_at=freeze_at,
norm=norm,
out_features=out_features,
)
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
import fvcore.nn.weight_init as weight_init
import torch
import torch.nn.functional as F
from torch import nn
from detectron2.layers import (
CNNBlockBase,
Conv2d,
DeformConv,
ModulatedDeformConv,
ShapeSpec,
get_norm,
)
from .backbone import Backbone
from .build import BACKBONE_REGISTRY
__all__ = [
"ResNetBlockBase",
"BasicBlock",
"BottleneckBlock",
"DeformBottleneckBlock",
"BasicStem",
"ResNet",
"make_stage",
"build_resnet_backbone",
]
class BasicBlock(CNNBlockBase):
"""
The basic residual block for ResNet-18 and ResNet-34 defined in :paper:`ResNet`,
with two 3x3 conv layers and a projection shortcut if needed.
"""
def __init__(self, in_channels, out_channels, *, stride=1, norm="BN"):
"""
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
stride (int): Stride for the first conv.
norm (str or callable): normalization for all conv layers.
See :func:`layers.get_norm` for supported format.
"""
super().__init__(in_channels, out_channels, stride)
if in_channels != out_channels:
self.shortcut = Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=stride,
bias=False,
norm=get_norm(norm, out_channels),
)
else:
self.shortcut = None
self.conv1 = Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
norm=get_norm(norm, out_channels),
)
self.conv2 = Conv2d(
out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False,
norm=get_norm(norm, out_channels),
)
for layer in [self.conv1, self.conv2, self.shortcut]:
if layer is not None: # shortcut can be None
weight_init.c2_msra_fill(layer)
def forward(self, x):
out = self.conv1(x)
out = F.relu_(out)
out = self.conv2(out)
if self.shortcut is not None:
shortcut = self.shortcut(x)
else:
shortcut = x
out += shortcut
out = F.relu_(out)
return out
class BottleneckBlock(CNNBlockBase):
"""
The standard bottleneck residual block used by ResNet-50, 101 and 152
defined in :paper:`ResNet`. It contains 3 conv layers with kernels
1x1, 3x3, 1x1, and a projection shortcut if needed.
"""
def __init__(
self,
in_channels,
out_channels,
*,
bottleneck_channels,
stride=1,
num_groups=1,
norm="BN",
stride_in_1x1=False,
dilation=1,
):
"""
Args:
bottleneck_channels (int): number of output channels for the 3x3
"bottleneck" conv layers.
num_groups (int): number of groups for the 3x3 conv layer.
norm (str or callable): normalization for all conv layers.
See :func:`layers.get_norm` for supported format.
stride_in_1x1 (bool): when stride>1, whether to put stride in the
first 1x1 convolution or the bottleneck 3x3 convolution.
dilation (int): the dilation rate of the 3x3 conv layer.
"""
super().__init__(in_channels, out_channels, stride)
if in_channels != out_channels:
self.shortcut = Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=stride,
bias=False,
norm=get_norm(norm, out_channels),
)
else:
self.shortcut = None
# The original MSRA ResNet models have stride in the first 1x1 conv
# The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
# stride in the 3x3 conv
stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
self.conv1 = Conv2d(
in_channels,
bottleneck_channels,
kernel_size=1,
stride=stride_1x1,
bias=False,
norm=get_norm(norm, bottleneck_channels),
)
self.conv2 = Conv2d(
bottleneck_channels,
bottleneck_channels,
kernel_size=3,
stride=stride_3x3,
padding=1 * dilation,
bias=False,
groups=num_groups,
dilation=dilation,
norm=get_norm(norm, bottleneck_channels),
)
self.conv3 = Conv2d(
bottleneck_channels,
out_channels,
kernel_size=1,
bias=False,
norm=get_norm(norm, out_channels),
)
for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
if layer is not None: # shortcut can be None
weight_init.c2_msra_fill(layer)
# Zero-initialize the last normalization in each residual branch,
# so that at the beginning, the residual branch starts with zeros,
# and each residual block behaves like an identity.
# See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
# "For BN layers, the learnable scaling coefficient γ is initialized
# to be 1, except for each residual block's last BN
# where γ is initialized to be 0."
# nn.init.constant_(self.conv3.norm.weight, 0)
# TODO this somehow hurts performance when training GN models from scratch.
# Add it as an option when we need to use this code to train a backbone.
def forward(self, x):
out = self.conv1(x)
out = F.relu_(out)
out = self.conv2(out)
out = F.relu_(out)
out = self.conv3(out)
if self.shortcut is not None:
shortcut = self.shortcut(x)
else:
shortcut = x
out += shortcut
out = F.relu_(out)
return out
class DeformBottleneckBlock(CNNBlockBase):
"""
Similar to :class:`BottleneckBlock`, but with :paper:`deformable conv <deformconv>`
in the 3x3 convolution.
"""
def __init__(
self,
in_channels,
out_channels,
*,
bottleneck_channels,
stride=1,
num_groups=1,
norm="BN",
stride_in_1x1=False,
dilation=1,
deform_modulated=False,
deform_num_groups=1,
):
super().__init__(in_channels, out_channels, stride)
self.deform_modulated = deform_modulated
if in_channels != out_channels:
self.shortcut = Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=stride,
bias=False,
norm=get_norm(norm, out_channels),
)
else:
self.shortcut = None
stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
self.conv1 = Conv2d(
in_channels,
bottleneck_channels,
kernel_size=1,
stride=stride_1x1,
bias=False,
norm=get_norm(norm, bottleneck_channels),
)
if deform_modulated:
deform_conv_op = ModulatedDeformConv
# offset channels are 2 or 3 (if with modulated) * kernel_size * kernel_size
offset_channels = 27
else:
deform_conv_op = DeformConv
offset_channels = 18
self.conv2_offset = Conv2d(
bottleneck_channels,
offset_channels * deform_num_groups,
kernel_size=3,
stride=stride_3x3,
padding=1 * dilation,
dilation=dilation,
)
self.conv2 = deform_conv_op(
bottleneck_channels,
bottleneck_channels,
kernel_size=3,
stride=stride_3x3,
padding=1 * dilation,
bias=False,
groups=num_groups,
dilation=dilation,
deformable_groups=deform_num_groups,
norm=get_norm(norm, bottleneck_channels),
)
self.conv3 = Conv2d(
bottleneck_channels,
out_channels,
kernel_size=1,
bias=False,
norm=get_norm(norm, out_channels),
)
for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
if layer is not None: # shortcut can be None
weight_init.c2_msra_fill(layer)
nn.init.constant_(self.conv2_offset.weight, 0)
nn.init.constant_(self.conv2_offset.bias, 0)
def forward(self, x):
out = self.conv1(x)
out = F.relu_(out)
if self.deform_modulated:
offset_mask = self.conv2_offset(out)
offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1)
offset = torch.cat((offset_x, offset_y), dim=1)
mask = mask.sigmoid()
out = self.conv2(out, offset, mask)
else:
offset = self.conv2_offset(out)
out = self.conv2(out, offset)
out = F.relu_(out)
out = self.conv3(out)
if self.shortcut is not None:
shortcut = self.shortcut(x)
else:
shortcut = x
out += shortcut
out = F.relu_(out)
return out
class BasicStem(CNNBlockBase):
"""
The standard ResNet stem (layers before the first residual block),
with a conv, relu and max_pool.
"""
def __init__(self, in_channels=3, out_channels=64, norm="BN"):
"""
Args:
norm (str or callable): norm after the first conv layer.
See :func:`layers.get_norm` for supported format.
"""
super().__init__(in_channels, out_channels, 4)
self.in_channels = in_channels
self.conv1 = Conv2d(
in_channels,
out_channels,
kernel_size=7,
stride=2,
padding=3,
bias=False,
norm=get_norm(norm, out_channels),
)
weight_init.c2_msra_fill(self.conv1)
def forward(self, x):
x = self.conv1(x)
x = F.relu_(x)
x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
return x
class ResNet(Backbone):
"""
Implement :paper:`ResNet`.
"""
def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0):
"""
Args:
stem (nn.Module): a stem module
stages (list[list[CNNBlockBase]]): several (typically 4) stages,
each contains multiple :class:`CNNBlockBase`.
num_classes (None or int): if None, will not perform classification.
Otherwise, will create a linear layer.
out_features (list[str]): name of the layers whose outputs should
be returned in forward. Can be anything in "stem", "linear", or "res2" ...
If None, will return the output of the last layer.
freeze_at (int): The number of stages at the beginning to freeze.
see :meth:`freeze` for detailed explanation.
"""
super().__init__()
self.stem = stem
self.num_classes = num_classes
current_stride = self.stem.stride
self._out_feature_strides = {"stem": current_stride}
self._out_feature_channels = {"stem": self.stem.out_channels}
self.stage_names, self.stages = [], []
if out_features is not None:
# Avoid keeping unused layers in this module. They consume extra memory
# and may cause allreduce to fail
num_stages = max(
[{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features]
)
stages = stages[:num_stages]
for i, blocks in enumerate(stages):
assert len(blocks) > 0, len(blocks)
for block in blocks:
assert isinstance(block, CNNBlockBase), block
name = "res" + str(i + 2)
stage = nn.Sequential(*blocks)
self.add_module(name, stage)
self.stage_names.append(name)
self.stages.append(stage)
self._out_feature_strides[name] = current_stride = int(
current_stride * np.prod([k.stride for k in blocks])
)
self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels
self.stage_names = tuple(self.stage_names) # Make it static for scripting
if num_classes is not None:
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.linear = nn.Linear(curr_channels, num_classes)
# Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
# "The 1000-way fully-connected layer is initialized by
# drawing weights from a zero-mean Gaussian with standard deviation of 0.01."
nn.init.normal_(self.linear.weight, std=0.01)
name = "linear"
if out_features is None:
out_features = [name]
self._out_features = out_features
assert len(self._out_features)
children = [x[0] for x in self.named_children()]
for out_feature in self._out_features:
assert out_feature in children, "Available children: {}".format(", ".join(children))
self.freeze(freeze_at)
def forward(self, x):
"""
Args:
x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
Returns:
dict[str->Tensor]: names and the corresponding features
"""
assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!"
outputs = {}
x = self.stem(x)
if "stem" in self._out_features:
outputs["stem"] = x
for name, stage in zip(self.stage_names, self.stages):
x = stage(x)
if name in self._out_features:
outputs[name] = x
if self.num_classes is not None:
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.linear(x)
if "linear" in self._out_features:
outputs["linear"] = x
return outputs
def output_shape(self):
return {
name: ShapeSpec(
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
)
for name in self._out_features
}
def freeze(self, freeze_at=0):
"""
Freeze the first several stages of the ResNet. Commonly used in
fine-tuning.
Layers that produce the same feature map spatial size are defined as one
"stage" by :paper:`FPN`.
Args:
freeze_at (int): number of stages to freeze.
`1` means freezing the stem. `2` means freezing the stem and
one residual stage, etc.
Returns:
nn.Module: this ResNet itself
"""
if freeze_at >= 1:
self.stem.freeze()
for idx, stage in enumerate(self.stages, start=2):
if freeze_at >= idx:
for block in stage.children():
block.freeze()
return self
@staticmethod
def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs):
"""
Create a list of blocks of the same type that forms one ResNet stage.
Args:
block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this
stage. A module of this type must not change spatial resolution of inputs unless its
stride != 1.
num_blocks (int): number of blocks in this stage
in_channels (int): input channels of the entire stage.
out_channels (int): output channels of **every block** in the stage.
kwargs: other arguments passed to the constructor of
`block_class`. If the argument name is "xx_per_block", the
argument is a list of values to be passed to each block in the
stage. Otherwise, the same argument is passed to every block
in the stage.
Returns:
list[CNNBlockBase]: a list of block module.
Examples:
::
stage = ResNet.make_stage(
BottleneckBlock, 3, in_channels=16, out_channels=64,
bottleneck_channels=16, num_groups=1,
stride_per_block=[2, 1, 1],
dilations_per_block=[1, 1, 2]
)
Usually, layers that produce the same feature map spatial size are defined as one
"stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should
all be 1.
"""
blocks = []
for i in range(num_blocks):
curr_kwargs = {}
for k, v in kwargs.items():
if k.endswith("_per_block"):
assert len(v) == num_blocks, (
f"Argument '{k}' of make_stage should have the "
f"same length as num_blocks={num_blocks}."
)
newk = k[: -len("_per_block")]
assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
curr_kwargs[newk] = v[i]
else:
curr_kwargs[k] = v
blocks.append(
block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs)
)
in_channels = out_channels
return blocks
@staticmethod
def make_default_stages(depth, block_class=None, **kwargs):
"""
Created list of ResNet stages from pre-defined depth (one of 18, 34, 50, 101, 152).
If it doesn't create the ResNet variant you need, please use :meth:`make_stage`
instead for fine-grained customization.
Args:
depth (int): depth of ResNet
block_class (type): the CNN block class. Has to accept
`bottleneck_channels` argument for depth > 50.
By default it is BasicBlock or BottleneckBlock, based on the
depth.
kwargs:
other arguments to pass to `make_stage`. Should not contain
stride and channels, as they are predefined for each depth.
Returns:
list[list[CNNBlockBase]]: modules in all stages; see arguments of
:class:`ResNet.__init__`.
"""
num_blocks_per_stage = {
18: [2, 2, 2, 2],
34: [3, 4, 6, 3],
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
152: [3, 8, 36, 3],
}[depth]
if block_class is None:
block_class = BasicBlock if depth < 50 else BottleneckBlock
if depth < 50:
in_channels = [64, 64, 128, 256]
out_channels = [64, 128, 256, 512]
else:
in_channels = [64, 256, 512, 1024]
out_channels = [256, 512, 1024, 2048]
ret = []
for n, s, i, o in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels):
if depth >= 50:
kwargs["bottleneck_channels"] = o // 4
ret.append(
ResNet.make_stage(
block_class=block_class,
num_blocks=n,
stride_per_block=[s] + [1] * (n - 1),
in_channels=i,
out_channels=o,
**kwargs,
)
)
return ret
ResNetBlockBase = CNNBlockBase
"""
Alias for backward compatibiltiy.
"""
def make_stage(*args, **kwargs):
"""
Deprecated alias for backward compatibiltiy.
"""
return ResNet.make_stage(*args, **kwargs)
@BACKBONE_REGISTRY.register()
def build_resnet_backbone(cfg, input_shape):
"""
Create a ResNet instance from config.
Returns:
ResNet: a :class:`ResNet` instance.
"""
# need registration of new blocks/stems?
norm = cfg.MODEL.RESNETS.NORM
stem = BasicStem(
in_channels=input_shape.channels,
out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
norm=norm,
)
# fmt: off
freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
out_features = cfg.MODEL.RESNETS.OUT_FEATURES
depth = cfg.MODEL.RESNETS.DEPTH
num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
bottleneck_channels = num_groups * width_per_group
in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1
res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION
deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE
deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED
deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS
# fmt: on
assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)
num_blocks_per_stage = {
18: [2, 2, 2, 2],
34: [3, 4, 6, 3],
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
152: [3, 8, 36, 3],
}[depth]
if depth in [18, 34]:
assert out_channels == 64, "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34"
assert not any(
deform_on_per_stage
), "MODEL.RESNETS.DEFORM_ON_PER_STAGE unsupported for R18/R34"
assert res5_dilation == 1, "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34"
assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34"
stages = []
for idx, stage_idx in enumerate(range(2, 6)):
# res5_dilation is used this way as a convention in R-FCN & Deformable Conv paper
dilation = res5_dilation if stage_idx == 5 else 1
first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
stage_kargs = {
"num_blocks": num_blocks_per_stage[idx],
"stride_per_block": [first_stride] + [1] * (num_blocks_per_stage[idx] - 1),
"in_channels": in_channels,
"out_channels": out_channels,
"norm": norm,
}
# Use BasicBlock for R18 and R34.
if depth in [18, 34]:
stage_kargs["block_class"] = BasicBlock
else:
stage_kargs["bottleneck_channels"] = bottleneck_channels
stage_kargs["stride_in_1x1"] = stride_in_1x1
stage_kargs["dilation"] = dilation
stage_kargs["num_groups"] = num_groups
if deform_on_per_stage[idx]:
stage_kargs["block_class"] = DeformBottleneckBlock
stage_kargs["deform_modulated"] = deform_modulated
stage_kargs["deform_num_groups"] = deform_num_groups
else:
stage_kargs["block_class"] = BottleneckBlock
blocks = ResNet.make_stage(**stage_kargs)
in_channels = out_channels
out_channels *= 2
bottleneck_channels *= 2
stages.append(blocks)
return ResNet(stem, stages, out_features=out_features, freeze_at=freeze_at)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Implementation of Swin models from :paper:`swin`.
This code is adapted from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py with minimal modifications. # noqa
--------------------------------------------------------
Swin Transformer
Copyright (c) 2021 Microsoft
Licensed under The MIT License [see LICENSE for details]
Written by Ze Liu, Yutong Lin, Yixuan Wei
--------------------------------------------------------
LICENSE: https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/461e003166a8083d0b620beacd4662a2df306bd6/LICENSE
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from detectron2.modeling.backbone.backbone import Backbone
_to_2tuple = nn.modules.utils._ntuple(2)
class Mlp(nn.Module):
"""Multilayer perceptron."""
def __init__(
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
"""Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value.
Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(
self,
dim,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""Forward function.
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B_, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class SwinTransformerBlock(nn.Module):
"""Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(
self,
dim,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
window_size=_to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
if drop_path > 0.0:
from timm.models.layers import DropPath
self.drop_path = DropPath(drop_path)
else:
self.drop_path = nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
)
self.H = None
self.W = None
def forward(self, x, mask_matrix):
"""Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
mask_matrix: Attention mask for cyclic shift.
"""
B, L, C = x.shape
H, W = self.H, self.W
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# pad feature maps to multiples of window size
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None
# partition windows
x_windows = window_partition(
shifted_x, self.window_size
) # nW*B, window_size, window_size, C
x_windows = x_windows.view(
-1, self.window_size * self.window_size, C
) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchMerging(nn.Module):
"""Patch Merging Layer
Args:
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
"""Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
class BasicLayer(nn.Module):
"""A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of feature channels
depth (int): Depths of this stage.
num_heads (int): Number of attention head.
window_size (int): Local window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer.
Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(
self,
dim,
depth,
num_heads,
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
):
super().__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList(
[
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
)
for i in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, H, W):
"""Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
h_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
w_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(
img_mask, self.window_size
) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
attn_mask == 0, float(0.0)
)
for blk in self.blocks:
blk.H, blk.W = H, W
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, attn_mask)
else:
x = blk(x, attn_mask)
if self.downsample is not None:
x_down = self.downsample(x, H, W)
Wh, Ww = (H + 1) // 2, (W + 1) // 2
return x, H, W, x_down, Wh, Ww
else:
return x, H, W, x, H, W
class PatchEmbed(nn.Module):
"""Image to Patch Embedding
Args:
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
patch_size = _to_2tuple(patch_size)
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, H, W = x.size()
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x) # B C Wh Ww
if self.norm is not None:
Wh, Ww = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
return x
class SwinTransformer(Backbone):
"""Swin Transformer backbone.
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted
Windows` - https://arxiv.org/pdf/2103.14030
Args:
pretrain_img_size (int): Input image size for training the pretrained model,
used in absolute postion embedding. Default 224.
patch_size (int | tuple(int)): Patch size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
depths (tuple[int]): Depths of each Swin Transformer stage.
num_heads (tuple[int]): Number of attention head of each stage.
window_size (int): Window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
drop_rate (float): Dropout rate.
attn_drop_rate (float): Attention dropout rate. Default: 0.
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
out_indices (Sequence[int]): Output from which stages.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(
self,
pretrain_img_size=224,
patch_size=4,
in_chans=3,
embed_dim=96,
depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24),
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.2,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
use_checkpoint=False,
):
super().__init__()
self.pretrain_img_size = pretrain_img_size
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.out_indices = out_indices
self.frozen_stages = frozen_stages
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None,
)
# absolute position embedding
if self.ape:
pretrain_img_size = _to_2tuple(pretrain_img_size)
patch_size = _to_2tuple(patch_size)
patches_resolution = [
pretrain_img_size[0] // patch_size[0],
pretrain_img_size[1] // patch_size[1],
]
self.absolute_pos_embed = nn.Parameter(
torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
)
nn.init.trunc_normal_(self.absolute_pos_embed, std=0.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2**i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
)
self.layers.append(layer)
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
self.num_features = num_features
# add a norm layer for each output
for i_layer in out_indices:
layer = norm_layer(num_features[i_layer])
layer_name = f"norm{i_layer}"
self.add_module(layer_name, layer)
self._freeze_stages()
self._out_features = ["p{}".format(i) for i in self.out_indices]
self._out_feature_channels = {
"p{}".format(i): self.embed_dim * 2**i for i in self.out_indices
}
self._out_feature_strides = {"p{}".format(i): 2 ** (i + 2) for i in self.out_indices}
self._size_devisibility = 32
self.apply(self._init_weights)
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
if self.frozen_stages >= 1 and self.ape:
self.absolute_pos_embed.requires_grad = False
if self.frozen_stages >= 2:
self.pos_drop.eval()
for i in range(0, self.frozen_stages - 1):
m = self.layers[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@property
def size_divisibility(self):
return self._size_divisibility
def forward(self, x):
"""Forward function."""
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
if self.ape:
# interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate(
self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
)
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
else:
x = x.flatten(2).transpose(1, 2)
x = self.pos_drop(x)
outs = {}
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
if i in self.out_indices:
norm_layer = getattr(self, f"norm{i}")
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
outs["p{}".format(i)] = out
return outs
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = [
"window_partition",
"window_unpartition",
"add_decomposed_rel_pos",
"get_abs_pos",
"PatchEmbed",
]
def window_partition(x, window_size):
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(windows, window_size, pad_hw, hw):
"""
Window unpartition into original sequences and removing padding.
Args:
x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
def get_rel_pos(q_size, k_size, rel_pos):
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
attn = (
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
).view(B, q_h * q_w, k_h * k_w)
return attn
def get_abs_pos(abs_pos, has_cls_token, hw):
"""
Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
dimension for the original embeddings.
Args:
abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
hw (Tuple): size of input image tokens.
Returns:
Absolute positional embeddings after processing with shape (1, H, W, C)
"""
h, w = hw
if has_cls_token:
abs_pos = abs_pos[:, 1:]
xy_num = abs_pos.shape[1]
size = int(math.sqrt(xy_num))
assert size * size == xy_num
if size != h or size != w:
new_abs_pos = F.interpolate(
abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
size=(h, w),
mode="bicubic",
align_corners=False,
)
return new_abs_pos.permute(0, 2, 3, 1)
else:
return abs_pos.reshape(1, h, w, -1)
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
):
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
def forward(self, x):
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
import logging
import math
import fvcore.nn.weight_init as weight_init
import torch
import torch.nn as nn
from detectron2.layers import CNNBlockBase, Conv2d, get_norm
from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous
from .backbone import Backbone
from .utils import (
PatchEmbed,
add_decomposed_rel_pos,
get_abs_pos,
window_partition,
window_unpartition,
)
logger = logging.getLogger(__name__)
__all__ = ["ViT", "SimpleFeaturePyramid", "get_vit_lr_decay_rate"]
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim,
num_heads=8,
qkv_bias=True,
use_rel_pos=False,
rel_pos_zero_init=True,
input_size=None,
):
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool: If True, add a learnable bias to query, key, value.
rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (int or None): Input resolution for calculating the relative positional
parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
if not rel_pos_zero_init:
nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
nn.init.trunc_normal_(self.rel_pos_w, std=0.02)
def forward(self, x):
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
return x
class ResBottleneckBlock(CNNBlockBase):
"""
The standard bottleneck residual block without the last activation layer.
It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
"""
def __init__(
self,
in_channels,
out_channels,
bottleneck_channels,
norm="LN",
act_layer=nn.GELU,
):
"""
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
bottleneck_channels (int): number of output channels for the 3x3
"bottleneck" conv layers.
norm (str or callable): normalization for all conv layers.
See :func:`layers.get_norm` for supported format.
act_layer (callable): activation for all conv layers.
"""
super().__init__(in_channels, out_channels, 1)
self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
self.norm1 = get_norm(norm, bottleneck_channels)
self.act1 = act_layer()
self.conv2 = Conv2d(
bottleneck_channels,
bottleneck_channels,
3,
padding=1,
bias=False,
)
self.norm2 = get_norm(norm, bottleneck_channels)
self.act2 = act_layer()
self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
self.norm3 = get_norm(norm, out_channels)
for layer in [self.conv1, self.conv2, self.conv3]:
weight_init.c2_msra_fill(layer)
for layer in [self.norm1, self.norm2]:
layer.weight.data.fill_(1.0)
layer.bias.data.zero_()
# zero init last norm layer.
self.norm3.weight.data.zero_()
self.norm3.bias.data.zero_()
def forward(self, x):
out = x
for layer in self.children():
out = layer(out)
out = x + out
return out
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=True,
drop_path=0.0,
norm_layer=nn.LayerNorm,
act_layer=nn.GELU,
use_rel_pos=False,
rel_pos_zero_init=True,
window_size=0,
use_residual_block=False,
input_size=None,
):
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
drop_path (float): Stochastic depth rate.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then not
use window attention.
use_residual_block (bool): If True, use a residual block after the MLP block.
input_size (int or None): Input resolution for calculating the relative positional
parameter size.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
from timm.models.layers import DropPath, Mlp
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)
self.window_size = window_size
self.use_residual_block = use_residual_block
if use_residual_block:
# Use a residual block with bottleneck channel as dim // 2
self.residual = ResBottleneckBlock(
in_channels=dim,
out_channels=dim,
bottleneck_channels=dim // 2,
norm="LN",
act_layer=act_layer,
)
def forward(self, x):
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
if self.use_residual_block:
x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
return x
class ViT(Backbone):
"""
This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
"Exploring Plain Vision Transformer Backbones for Object Detection",
https://arxiv.org/abs/2203.16527
"""
def __init__(
self,
img_size=1024,
patch_size=16,
in_chans=3,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=True,
drop_path_rate=0.0,
norm_layer=nn.LayerNorm,
act_layer=nn.GELU,
use_abs_pos=True,
use_rel_pos=False,
rel_pos_zero_init=True,
window_size=0,
window_block_indexes=(),
residual_block_indexes=(),
use_act_checkpoint=False,
pretrain_img_size=224,
pretrain_use_cls_token=True,
out_feature="last_feat",
):
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
drop_path_rate (float): Stochastic depth rate.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
window_block_indexes (list): Indexes for blocks using window attention.
residual_block_indexes (list): Indexes for blocks using conv propagation.
use_act_checkpoint (bool): If True, use activation checkpointing.
pretrain_img_size (int): input image size for pretraining models.
pretrain_use_cls_token (bool): If True, pretrainig models use class token.
out_feature (str): name of the feature from the last block.
"""
super().__init__()
self.pretrain_use_cls_token = pretrain_use_cls_token
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
else:
self.pos_embed = None
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i in window_block_indexes else 0,
use_residual_block=i in residual_block_indexes,
input_size=(img_size // patch_size, img_size // patch_size),
)
if use_act_checkpoint:
# TODO: use torch.utils.checkpoint
from fairscale.nn.checkpoint import checkpoint_wrapper
block = checkpoint_wrapper(block)
self.blocks.append(block)
self._out_feature_channels = {out_feature: embed_dim}
self._out_feature_strides = {out_feature: patch_size}
self._out_features = [out_feature]
if self.pos_embed is not None:
nn.init.trunc_normal_(self.pos_embed, std=0.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + get_abs_pos(
self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
)
for blk in self.blocks:
x = blk(x)
outputs = {self._out_features[0]: x.permute(0, 3, 1, 2)}
return outputs
class SimpleFeaturePyramid(Backbone):
"""
This module implements SimpleFeaturePyramid in :paper:`vitdet`.
It creates pyramid features built on top of the input feature map.
"""
def __init__(
self,
net,
in_feature,
out_channels,
scale_factors,
top_block=None,
norm="LN",
square_pad=0,
):
"""
Args:
net (Backbone): module representing the subnetwork backbone.
Must be a subclass of :class:`Backbone`.
in_feature (str): names of the input feature maps coming
from the net.
out_channels (int): number of channels in the output feature maps.
scale_factors (list[float]): list of scaling factors to upsample or downsample
the input features for creating pyramid features.
top_block (nn.Module or None): if provided, an extra operation will
be performed on the output of the last (smallest resolution)
pyramid output, and the result will extend the result list. The top_block
further downsamples the feature map. It must have an attribute
"num_levels", meaning the number of extra pyramid levels added by
this block, and "in_feature", which is a string representing
its input feature (e.g., p5).
norm (str): the normalization to use.
square_pad (int): If > 0, require input images to be padded to specific square size.
"""
super(SimpleFeaturePyramid, self).__init__()
assert isinstance(net, Backbone)
self.scale_factors = scale_factors
input_shapes = net.output_shape()
strides = [int(input_shapes[in_feature].stride / scale) for scale in scale_factors]
_assert_strides_are_log2_contiguous(strides)
dim = input_shapes[in_feature].channels
self.stages = []
use_bias = norm == ""
for idx, scale in enumerate(scale_factors):
out_dim = dim
if scale == 4.0:
layers = [
nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
get_norm(norm, dim // 2),
nn.GELU(),
nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
]
out_dim = dim // 4
elif scale == 2.0:
layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)]
out_dim = dim // 2
elif scale == 1.0:
layers = []
elif scale == 0.5:
layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
raise NotImplementedError(f"scale_factor={scale} is not supported yet.")
layers.extend(
[
Conv2d(
out_dim,
out_channels,
kernel_size=1,
bias=use_bias,
norm=get_norm(norm, out_channels),
),
Conv2d(
out_channels,
out_channels,
kernel_size=3,
padding=1,
bias=use_bias,
norm=get_norm(norm, out_channels),
),
]
)
layers = nn.Sequential(*layers)
stage = int(math.log2(strides[idx]))
self.add_module(f"simfp_{stage}", layers)
self.stages.append(layers)
self.net = net
self.in_feature = in_feature
self.top_block = top_block
# Return feature names are "p<stage>", like ["p2", "p3", ..., "p6"]
self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
# top block output feature maps.
if self.top_block is not None:
for s in range(stage, stage + self.top_block.num_levels):
self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)
self._out_features = list(self._out_feature_strides.keys())
self._out_feature_channels = {k: out_channels for k in self._out_features}
self._size_divisibility = strides[-1]
self._square_pad = square_pad
@property
def padding_constraints(self):
return {
"size_divisiblity": self._size_divisibility,
"square_size": self._square_pad,
}
def forward(self, x):
"""
Args:
x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
Returns:
dict[str->Tensor]:
mapping from feature map name to pyramid feature map tensor
in high to low resolution order. Returned feature names follow the FPN
convention: "p<stage>", where stage has stride = 2 ** stage e.g.,
["p2", "p3", ..., "p6"].
"""
bottom_up_features = self.net(x)
features = bottom_up_features[self.in_feature]
results = []
for stage in self.stages:
results.append(stage(features))
if self.top_block is not None:
if self.top_block.in_feature in bottom_up_features:
top_block_in_feature = bottom_up_features[self.top_block.in_feature]
else:
top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
results.extend(self.top_block(top_block_in_feature))
assert len(self._out_features) == len(results)
return {f: res for f, res in zip(self._out_features, results)}
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
"""
Calculate lr decay rate for different ViT blocks.
Args:
name (string): parameter name.
lr_decay_rate (float): base lr decay rate.
num_layers (int): number of ViT blocks.
Returns:
lr decay rate for the given parameter.
"""
layer_id = num_layers + 1
if name.startswith("backbone"):
if ".pos_embed" in name or ".patch_embed" in name:
layer_id = 0
elif ".blocks." in name and ".residual." not in name:
layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
return lr_decay_rate ** (num_layers + 1 - layer_id)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment