Commit 1345fab2 authored by luopl

Initial commit
import argparse
import os
from functools import partial
from test import create_test_data_loader
from typing import Dict, List, Tuple
import accelerate
import cv2
import numpy as np
import torch
import torch.utils.data as data
from accelerate import Accelerator
from PIL import Image
from tqdm import tqdm
from util.lazy_load import Config
from util.logger import setup_logger
from util.utils import load_checkpoint, load_state_dict
from util.visualize import plot_bounding_boxes_on_image_cv2
def is_image(file_path):
try:
img = Image.open(file_path)
img.close()
return True
except Exception:
return False
def parse_args():
parser = argparse.ArgumentParser(description="Inference a detector")
# dataset parameters
parser.add_argument("--image-dir", type=str, required=True)
parser.add_argument("--workers", type=int, default=0)
# model parameters
parser.add_argument("--model-config", type=str, required=True)
parser.add_argument("--checkpoint", type=str, required=True)
# visualization parameters
parser.add_argument("--show-dir", type=str, default=None)
parser.add_argument("--show-conf", type=float, default=0.5)
# plot parameters
parser.add_argument("--font-scale", type=float, default=1.0)
parser.add_argument("--box-thick", type=int, default=1)
parser.add_argument("--fill-alpha", type=float, default=0.2)
parser.add_argument("--text-box-color", type=int, nargs="+", default=(255, 255, 255))
parser.add_argument("--text-font-color", type=int, nargs="+", default=None)
parser.add_argument("--text-alpha", type=float, default=1.0)
# engine parameters
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()
return args
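# Example invocation (a sketch; the paths below are placeholders, not files from this commit):
#   python inference.py --image-dir ./demo_images --model-config configs/model_config.py \
#       --checkpoint checkpoint.pth --show-dir ./visualization --show-conf 0.5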
class InferenceDataset(data.Dataset):
def __init__(self, root):
self.images = [os.path.join(root, img) for img in os.listdir(root)]
self.images = [img for img in self.images if is_image(img)]
assert len(self.images) > 0, "No images found"
def __len__(self):
return len(self.images)
def __getitem__(self, index):
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
image = cv2.imdecode(np.fromfile(self.images[index], dtype=np.uint8), -1)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
return torch.tensor(image)
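# Example: a minimal sketch of exercising InferenceDataset on its own, assuming
# "./demo_images" is a hypothetical directory containing at least one readable image.
def _inference_dataset_demo(image_dir="./demo_images"):
    dataset = InferenceDataset(image_dir)
    image = dataset[0]
    # each item is a uint8 CHW tensor in RGB channel order
    print(len(dataset), image.shape, image.dtype)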
def inference():
args = parse_args()
# set fixed seed and deterministic_algorithms
accelerator = Accelerator()
accelerate.utils.set_seed(args.seed, device_specific=False)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# torch.use_deterministic_algorithms raises RuntimeError on older PyTorch versions
# torch.use_deterministic_algorithms(True, warn_only=True)
# setup logger
for logger_name in ["py.warnings", "accelerate", os.path.basename(os.getcwd())]:
setup_logger(distributed_rank=accelerator.local_process_index, name=logger_name)
dataset = InferenceDataset(args.image_dir)
data_loader = create_test_data_loader(
dataset, accelerator=accelerator, batch_size=1, num_workers=args.workers
)
# get inference results from model output
model = Config(args.model_config).model.eval()
checkpoint = load_checkpoint(args.checkpoint)
if isinstance(checkpoint, Dict) and "model" in checkpoint:
checkpoint = checkpoint["model"]
load_state_dict(model, checkpoint)
model = accelerator.prepare_model(model)
with torch.inference_mode():
predictions = []
for index, images in enumerate(tqdm(data_loader)):
prediction = model(images)[0]
# move prediction tensors to CPU
for key in prediction:
prediction[key] = prediction[key].to("cpu", non_blocking=True)
image_name = data_loader.dataset.images[index]
image = images[0].to("cpu", non_blocking=True)
prediction = {"image_name": image_name, "image": image, "output": prediction}
predictions.append(prediction)
# save visualization results
if args.show_dir:
os.makedirs(args.show_dir, exist_ok=True)
# create a dummy data loader so visualization can run with multiple workers
data_loader = create_test_data_loader(
predictions, accelerator=accelerator, batch_size=1, num_workers=args.workers
)
data_loader.collate_fn = partial(_visualize_batch_for_infer, classes=model.CLASSES, **vars(args))
# iterating the loader triggers _visualize_batch_for_infer through collate_fn; the results are discarded
[None for _ in tqdm(data_loader)]
def _visualize_batch_for_infer(
batch: Tuple[Dict],
classes: List[str],
show_conf: float = 0.0,
show_dir: str = None,
font_scale: float = 1.0,
box_thick: int = 3,
fill_alpha: float = 0.2,
text_box_color: Tuple[int] = (255, 255, 255),
text_font_color: Tuple[int] = None,
text_alpha: float = 0.5,
**kwargs,  # remaining entries from vars(args) are unused here
):
image_name, image, output = batch[0].values()
# plot bounding boxes on image
image = image.numpy().transpose(1, 2, 0)
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
image = plot_bounding_boxes_on_image_cv2(
image=image,
boxes=output["boxes"],
labels=output["labels"],
scores=output.get("scores", None),
classes=classes,
show_conf=show_conf,
font_scale=font_scale,
box_thick=box_thick,
fill_alpha=fill_alpha,
text_box_color=text_box_color,
text_font_color=text_font_color,
text_alpha=text_alpha,
)
cv2.imwrite(os.path.join(show_dir, os.path.basename(image_name)), image)
if __name__ == "__main__":
inference()
import argparse
import datetime
import os
import pprint
import re
import time
import accelerate
import torch
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.logging import get_logger
from accelerate.tracking import TensorBoardTracker
from accelerate.utils import ProjectConfiguration
from torch.utils import data
from util.collate_fn import collate_fn
from util.engine import evaluate_acc, train_one_epoch_acc
from util.group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
from util.lazy_load import Config
from util.misc import default_setup, encode_labels, fixed_generator, seed_worker
from util.utils import HighestCheckpoint, load_checkpoint, load_state_dict
def parse_args():
parser = argparse.ArgumentParser(description="Train a detector")
parser.add_argument("--config-file", default="configs/train_config.py")
parser.add_argument(
"--mixed-precision",
type=str,
default=None,
choices=["no", "fp16", "bf16", "fp8"],
help="Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU.",
)
parser.add_argument(
"--accumulate-steps", type=int, default=1, help="Steps to accumulate gradients"
)
parser.add_argument("--seed", type=int, help="Random seed")
parser.add_argument("--use-deterministic-algorithms", action="store_true")
dynamo_backend = ["no", "eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser"]
dynamo_backend += ["cudagraphs", "ofi", "fx2trt", "onnxrt", "tensorrt", "ipex", "tvm"]
parser.add_argument(
"--dynamo-backend",
type=str,
default="no",
choices=dynamo_backend,
help="""
Set to one of the possible dynamo backends to optimize the training with torch dynamo.
See https://pytorch.org/docs/stable/torch.compiler.html and
https://huggingface.co/docs/accelerate/main/en/package_reference/utilities#accelerate.utils.DynamoBackend
""",
)
args = parser.parse_args()
return args
def train():
args = parse_args()
cfg = Config(args.config_file, partials=("lr_scheduler", "optimizer", "param_dicts"))
# modify output directory
if getattr(cfg, "output_dir", None) is None:
if hasattr(cfg, "resume_from_checkpoint") and os.path.isdir(str(cfg.resume_from_checkpoint)):
# default path: xxxx-xx-xx-yy_yy_yy/checkpoints/{checkpoint_1}
if "checkpoints" in os.listdir(cfg.resume_from_checkpoint):
# resume_from_checkpoint is an output_dir: find the newest checkpoint under its checkpoints directory
output_dir = os.path.join(cfg.resume_from_checkpoint, "checkpoints")
folders = [os.path.join(output_dir, folder) for folder in os.listdir(output_dir)]
folders.sort(
key=lambda folder:
list(map(int, re.findall(r"[\/]?([0-9]+)(?=[^\/]*$)", folder)))[0]
)
cfg.resume_from_checkpoint = folders[-1]
if "checkpoints" in os.path.dirname(cfg.resume_from_checkpoint):
cfg.output_dir = os.path.dirname(os.path.dirname(cfg.resume_from_checkpoint))
else:
# make sure all processes have same output directory
accelerate.utils.wait_for_everyone()
cfg.output_dir = os.path.join(
"checkpoints",
os.path.basename(cfg.model_path).split(".")[0],
"train",
datetime.datetime.now().strftime("%Y-%m-%d-%H_%M_%S"),
)
# Initialize accelerator
project_config = ProjectConfiguration(
project_dir=cfg.output_dir, total_limit=5, automatic_checkpoint_naming=True
)
tensorboard_tracker = TensorBoardTracker(run_name="tf_log", logging_dir=cfg.output_dir)
kwargs = DistributedDataParallelKwargs(find_unused_parameters=cfg.find_unused_parameters)
accelerator = Accelerator(
log_with=tensorboard_tracker,
project_config=project_config,
mixed_precision=args.mixed_precision,
gradient_accumulation_steps=args.accumulate_steps,
dynamo_backend=args.dynamo_backend,
step_scheduler_with_optimizer=False,
kwargs_handlers=[kwargs],
)
accelerator.init_trackers("det_train")
default_setup(args, cfg, accelerator)
# instantiate dataset
params = dict(num_workers=cfg.num_workers, collate_fn=collate_fn)
params.update(dict(pin_memory=cfg.pin_memory, persistent_workers=True))
if args.use_deterministic_algorithms:
# set using deterministic algorithms
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True, warn_only=True)
params.update({"worker_init_fn": seed_worker, "generator": fixed_generator()})
# we use an aspect-ratio-grouped batch sampler, which slightly increases training speed
group_ids = create_aspect_ratio_groups(cfg.train_dataset, k=3)
train_batch_sampler = GroupedBatchSampler(
data.RandomSampler(cfg.train_dataset), group_ids, cfg.batch_size
)
train_loader = data.DataLoader(cfg.train_dataset, batch_sampler=train_batch_sampler, **params)
test_loader = data.DataLoader(cfg.test_dataset, 1, shuffle=False, **params)
# instantiate model, optimizer and lr_scheduler
model = Config(cfg.model_path).model
if accelerator.use_distributed:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
optimizer = cfg.optimizer(cfg.param_dicts(model))
lr_scheduler = cfg.lr_scheduler(optimizer)
# register dataset class information into the model, useful for inference
cat_ids = list(range(max(cfg.train_dataset.coco.cats.keys()) + 1))
classes = tuple(cfg.train_dataset.coco.cats.get(c, {"name": "none"})["name"] for c in cat_ids)
model.register_buffer("_classes_", torch.tensor(encode_labels(classes)))
# log the configurations
logger = get_logger(os.path.basename(os.getcwd()) + "." + __name__)
# prepare for distributed training
model, optimizer, train_loader, test_loader, lr_scheduler = accelerator.prepare(
model, optimizer, train_loader, test_loader, lr_scheduler
)
if getattr(cfg, "resume_from_checkpoint", None) is not None:
if os.path.isdir(str(cfg.resume_from_checkpoint)):
accelerator.load_state(cfg.resume_from_checkpoint)
path = os.path.basename(cfg.resume_from_checkpoint)
cfg.starting_epoch = int(path.split("_")[-1]) + 1
accelerator.project_configuration.iteration = cfg.starting_epoch
logger.info(f"resume training of {cfg.output_dir}, from {path}")
elif os.path.isfile(str(cfg.resume_from_checkpoint)):
checkpoint = load_checkpoint(cfg.resume_from_checkpoint)
checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint
load_state_dict(accelerator.unwrap_model(model), checkpoint)
# overwrite _classes_ in checkpoint with the current dataset's categories
model.register_buffer("_classes_", torch.tensor(encode_labels(classes)))
logger.info(
f"load pretrained from {cfg.resume_from_checkpoint}, output_dir is {cfg.output_dir}"
)
else:
logger.warning("resume_from_checkpoint is neither a directory nor a file, skipping loading")
else:
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info("model parameters: {}".format(n_params))
logger.info("optimizer: {}".format(optimizer))
logger.info("lr_scheduler: {}".format(pprint.pformat(lr_scheduler.state_dict())))
# save dataset label names, useful for inference
if accelerator.is_main_process:
label_file = os.path.join(cfg.output_dir, "label_names.txt")
with open(label_file, "w") as f:
caid_name = [f"{k} {v['name']}" for k, v in cfg.train_dataset.coco.cats.items()]
caid_name = "\n".join(caid_name)
f.write(caid_name)
logger.info(f"Label names is saved to {label_file}")
logger.info("Start training")
start_time = time.perf_counter()
highest_checkpoint = HighestCheckpoint(accelerator, model)
for epoch in range(cfg.starting_epoch, cfg.num_epochs):
train_one_epoch_acc(
model=model,
optimizer=optimizer,
data_loader=train_loader,
epoch=epoch,
print_freq=cfg.print_freq,
max_grad_norm=cfg.max_norm,
accelerator=accelerator,
)
lr_scheduler.step()
# we save model and labels together
accelerator.save_state(safe_serialization=False)
logger.info("Start evaluation")
coco_evaluator = evaluate_acc(model, test_loader, epoch, accelerator)
# save best results
cur_ap, cur_ap50 = coco_evaluator.coco_eval["bbox"].stats[:2]
highest_checkpoint.update(ap=cur_ap, ap50=cur_ap50)
total_time = time.perf_counter() - start_time
total_time = str(datetime.timedelta(seconds=int(total_time)))
logger.info("Training time: {}".format(total_time))
accelerator.end_training()
if __name__ == "__main__":
train()
# Unique model identifier
modelCode=730
# Model name
modelName=salience_detr_pytorch
# Model description
modelDescription=Salience_DETR: enhancing the detection transformer with hierarchical salience filtering refinement (inference and training)
# Application scenarios
appScenario=training,inference,research,manufacturing,healthcare,home,education
# Framework type
frameType=Pytorch
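# Example: a minimal sketch of reading the key=value metadata above in Python; the file
# name "model.properties" is an assumption, not something given by this commit.
def _read_model_metadata(path="model.properties"):
    metadata = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue  # skip blank lines and comment lines
            key, _, value = line.partition("=")
            metadata[key.strip()] = value.strip()
    return metadata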
import inspect
import logging
import os
from typing import Dict
from omegaconf import DictConfig
from torch import nn
from util.utils import load_state_dict as _load_state_dict
class BaseBackbone:
@staticmethod
def load_state_dict(model: nn.Module, state_dict: Dict):
if state_dict is None:
return
assert isinstance(state_dict, Dict), "state_dict must be a dict."
_load_state_dict(model, state_dict)
@staticmethod
def freeze_module(module: nn.Module):
module.eval()
for param in module.parameters():
param.requires_grad = False
def get_instantiate_config(self, func_name, arch, extra_params):
# log some necessary information about backbone
logger = logging.getLogger(os.path.basename(os.getcwd()) + "." + __name__)
assert arch is None or arch in self.model_arch, \
f"Expected architecture in {self.model_arch.keys()} but got {arch}"
logger.info(f"Backbone architecture: {arch}")
# merge parameters from self.arch with extra params
model_config = self.model_arch[arch] if arch is not None else {}
for name, param in inspect.signature(func_name).parameters.items():
# get default, current and modified params
default = param.default if param.default is not inspect.Parameter.empty else None
modified_param = extra_params.get(name, None)
if isinstance(model_config, Dict):
cur_param = model_config.get(name, None)
elif isinstance(model_config, DictConfig):
cur_param = getattr(model_config, name, None)
else:
cur_param = None
# choose the highest-priority parameter: extra_params > arch config > signature default
if cur_param is not None:
default = cur_param
if modified_param is not None:
default = modified_param
# replace parameters in model_config
if isinstance(model_config, Dict):
model_config[name] = default
elif isinstance(model_config, DictConfig):
setattr(model_config, name, default)
else:
raise TypeError("Only Dict and DictConfig supported.")
return model_config
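# Example: a minimal sketch of the merge priority implemented above, where extra_params
# override the per-arch config, which in turn overrides the callable's own defaults.
# _ToyBackbone and _toy_model are purely illustrative names, not part of the repo.
class _ToyBackbone(BaseBackbone):
    model_arch = {"toy": {"depth": 50}}

def _toy_model(depth=18, width=64):
    return {"depth": depth, "width": width}

# _ToyBackbone.get_instantiate_config(_ToyBackbone, _toy_model, "toy", {"width": 128})
# returns {"depth": 50, "width": 128}: depth from the arch config, width from extra_params.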
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.ops.stochastic_depth import StochasticDepth
from models.backbones.base_backbone import BaseBackbone
from models.bricks.misc import Conv2dNormActivation, Permute
from util.lazy_load import LazyCall as L
from util.lazy_load import instantiate
from util.utils import load_checkpoint
class LayerNorm2d(nn.LayerNorm):
def forward(self, x: Tensor) -> Tensor:
x = x.permute(0, 2, 3, 1)
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
x = x.permute(0, 3, 1, 2)
return x
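# Example: a minimal sketch showing that LayerNorm2d normalizes NCHW features over the
# channel dimension by permuting to channel-last, applying LayerNorm, and permuting back.
def _layer_norm_2d_demo():
    norm = LayerNorm2d(64, eps=1e-6)
    x = torch.randn(2, 64, 32, 32)
    y = norm(x)
    assert y.shape == x.shape
    # with default affine parameters, per-pixel channel statistics are close to zero mean
    assert torch.allclose(y.mean(dim=1), torch.zeros(2, 32, 32), atol=1e-4)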
class CNBlock(nn.Module):
def __init__(
self,
dim,
layer_scale: float,
stochastic_depth_prob: float,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.block = nn.Sequential(
nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
Permute([0, 2, 3, 1]),
norm_layer(dim),
nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
nn.GELU(),
nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
Permute([0, 3, 1, 2]),
)
self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
def forward(self, input: Tensor) -> Tensor:
result = self.layer_scale * self.block(input)
result = self.stochastic_depth(result)
result += input
return result
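# Example: a minimal shape-check sketch (sizes are hypothetical); CNBlock is a residual
# block, so it preserves the NCHW shape of its input.
def _cn_block_demo():
    block = CNBlock(dim=96, layer_scale=1e-6, stochastic_depth_prob=0.0)
    x = torch.randn(1, 96, 56, 56)
    assert block(x).shape == x.shape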
class CNBlockConfig:
# Stores information listed at Section 3 of the ConvNeXt paper
def __init__(
self,
input_channels: int,
out_channels: Optional[int],
num_layers: int,
) -> None:
self.input_channels = input_channels
self.out_channels = out_channels
self.num_layers = num_layers
def __repr__(self) -> str:
s = self.__class__.__name__ + "("
s += "input_channels={input_channels}"
s += ", out_channels={out_channels}"
s += ", num_layers={num_layers}"
s += ")"
return s.format(**self.__dict__)
class ConvNeXt(nn.Module):
def __init__(
self,
block_setting: List[CNBlockConfig],
stochastic_depth_prob: float = 0.0,
layer_scale: float = 1e-6,
num_classes: int = 1000,
block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any,
) -> None:
super().__init__()
if not block_setting:
raise ValueError("The block_setting should not be empty")
elif not (
isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])
):
raise TypeError("The block_setting should be List[CNBlockConfig]")
if block is None:
block = CNBlock
if norm_layer is None:
norm_layer = partial(LayerNorm2d, eps=1e-6)
layers: List[nn.Module] = []
# Stem
firstconv_output_channels = block_setting[0].input_channels
layers.append(
Conv2dNormActivation(
3,
firstconv_output_channels,
kernel_size=4,
stride=4,
padding=0,
norm_layer=norm_layer,
activation_layer=None,
bias=True,
)
)
total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
stage_block_id = 0
for cnf in block_setting:
# Bottlenecks
stage: List[nn.Module] = []
for _ in range(cnf.num_layers):
# adjust stochastic depth probability based on the depth of the stage block
sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
stage.append(block(cnf.input_channels, layer_scale, sd_prob))
stage_block_id += 1
layers.append(nn.Sequential(*stage))
if cnf.out_channels is not None:
# Downsampling
layers.append(
nn.Sequential(
norm_layer(cnf.input_channels),
nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
)
)
self.features = nn.Sequential(*layers)
self.avgpool = nn.AdaptiveAvgPool2d(1)
lastblock = block_setting[-1]
lastconv_output_channels = (
lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
)
self.classifier = nn.Sequential(
norm_layer(lastconv_output_channels), nn.Flatten(1),
nn.Linear(lastconv_output_channels, num_classes)
)
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
def _forward_impl(self, x: Tensor) -> Tensor:
x = self.features(x)
x = self.avgpool(x)
x = self.classifier(x)
return x
def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)
class ConvNeXtBackbone(BaseBackbone):
# yapf: disable
model_weights = {
# The following weights are from torchvision
"conv_t": "https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
"conv_s": "https://download.pytorch.org/models/convnext_small-0c510722.pth",
"conv_b": "https://download.pytorch.org/models/convnext_base-6075fbad.pth",
"conv_l": "https://download.pytorch.org/models/convnext_large-ea097f82.pth",
}
model_arch = {
"conv_t": L(ConvNeXt)(
block_setting=[
CNBlockConfig(96, 192, 3),
CNBlockConfig(192, 384, 3),
CNBlockConfig(384, 768, 9),
CNBlockConfig(768, None, 3),
],
stochastic_depth_prob=0.1,
url=model_weights["conv_t"],
),
"conv_s": L(ConvNeXt)(
block_setting=[
CNBlockConfig(96, 192, 3),
CNBlockConfig(192, 384, 3),
CNBlockConfig(384, 768, 27),
CNBlockConfig(768, None, 3),
],
stochastic_depth_prob=0.4,
url=model_weights["conv_s"],
),
"conv_b": L(ConvNeXt)(
block_setting=[
CNBlockConfig(128, 256, 3),
CNBlockConfig(256, 512, 3),
CNBlockConfig(512, 1024, 27),
CNBlockConfig(1024, None, 3),
],
stochastic_depth_prob=0.5,
url=model_weights["conv_b"],
),
"conv_l": L(ConvNeXt)(
block_setting=[
CNBlockConfig(192, 384, 3),
CNBlockConfig(384, 768, 3),
CNBlockConfig(768, 1536, 27),
CNBlockConfig(1536, None, 3),
],
stochastic_depth_prob=0.5,
url=model_weights["conv_l"],
)
}
# yapf: enable
def __new__(
self,
arch: str,
weights: Union[str, Dict] = None,
return_indices: Tuple[int] = (0, 1, 2, 3),
freeze_indices: Tuple = (),
**kwargs,
):
# get parameters and instantiate backbone
model_config = self.get_instantiate_config(self, ConvNeXt, arch, kwargs)
default_weight = model_config.pop("url", None)
convnext = instantiate(model_config)
# load state dict
weights = load_checkpoint(default_weight if weights is None else weights)
if isinstance(weights, Dict):
weights = weights["model"] if "model" in weights else weights
self.load_state_dict(convnext, weights)
# freeze stages
self._freeze_stages(self, convnext, freeze_indices)
# create feature extractor
return_layers = [f"features.{2 * idx + 1}" for idx in return_indices]
convnext = create_feature_extractor(convnext, return_layers)
convnext.num_channels = [model_config.block_setting[i].input_channels for i in return_indices]
return convnext
def _freeze_stages(self, model: nn.Module, freeze_indices: Tuple[int]):
# freeze stem
if len(freeze_indices) > 0:
self.freeze_module(model.features[0])
for idx in freeze_indices:
# freeze layers
self.freeze_module(model.features[2 * idx + 1])
# freeze downsample layers
if 2 * idx + 2 < len(model.features):
self.freeze_module(model.features[2 * idx + 2])
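# Example: a minimal usage sketch, assuming the repo's util modules are importable and the
# default torchvision weights can be downloaded; the 224x224 input size is arbitrary.
def _convnext_backbone_demo():
    backbone = ConvNeXtBackbone(arch="conv_t", return_indices=(1, 2, 3), freeze_indices=(0,))
    feats = backbone(torch.randn(1, 3, 224, 224))
    # the feature extractor returns a dict of channel-first multi-scale feature maps
    for name, feat in feats.items():
        print(name, tuple(feat.shape))
    print("num_channels:", backbone.num_channels)  # [192, 384, 768] for conv_t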
import os
import sys
from collections import OrderedDict
from functools import partial
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.ops import StochasticDepth
from models.backbones.base_backbone import BaseBackbone
from util.lazy_load import LazyCall as L
from util.lazy_load import instantiate
from util.utils import load_checkpoint
class Mlp(nn.Module):
"""Multilayer perceptron."""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
activation_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = activation_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class FocalModulation(nn.Module):
""" Focal Modulation
Args:
dim (int): Number of input channels.
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
focal_level (int): Number of focal levels
focal_window (int): Focal window size at focal level 1
focal_factor (int, default=2): Step to increase the focal window
use_postln (bool, default=False): Whether use post-modulation layernorm
"""
def __init__(
self,
dim: int,
proj_drop=0.,
focal_level=2,
focal_window=7,
focal_factor=2,
use_postln_in_modulation=False,
normalize_modulator=False
):
super().__init__()
self.dim = dim
# specific args for focalv3
self.focal_level = focal_level
self.focal_window = focal_window
self.focal_factor = focal_factor
self.use_postln_in_modulation = use_postln_in_modulation
self.normalize_modulator = normalize_modulator
self.f = nn.Linear(dim, 2 * dim + (self.focal_level + 1))
self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1)
self.act = nn.GELU()
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.focal_layers = nn.ModuleList()
if self.use_postln_in_modulation:
self.ln = nn.LayerNorm(dim)
for k in range(self.focal_level):
kernel_size = self.focal_factor * k + self.focal_window
self.focal_layers.append(
nn.Sequential(
nn.Conv2d(
dim,
dim,
kernel_size=kernel_size,
stride=1,
groups=dim,
padding=kernel_size // 2,
bias=False
),
nn.GELU(),
)
)
def forward(self, x):
""" Forward function.
Args:
x: input features with shape of (B, H, W, C)
"""
C = x.shape[-1]
# pre linear projection
x = self.f(x).permute(0, 3, 1, 2).contiguous()
q, ctx, gates = torch.split(x, (C, C, self.focal_level + 1), 1)
# context aggregation
ctx_all = 0
for l in range(self.focal_level):
ctx = self.focal_layers[l](ctx)
ctx_all = ctx_all + ctx * gates[:, l:l + 1]
ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
ctx_all = ctx_all + ctx_global * gates[:, self.focal_level:]
# normalize context
if self.normalize_modulator:
ctx_all = ctx_all / (self.focal_level + 1)
# focal modulation
x_out = q * self.h(ctx_all)
x_out = x_out.permute(0, 2, 3, 1).contiguous()
if self.use_postln_in_modulation:
x_out = self.ln(x_out)
# post linear projection
x_out = self.proj(x_out)
x_out = self.proj_drop(x_out)
return x_out
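# Example: a minimal shape-check sketch (sizes are hypothetical); FocalModulation consumes
# and returns channel-last (B, H, W, C) features.
def _focal_modulation_demo():
    m = FocalModulation(dim=96, focal_level=2, focal_window=7)
    x = torch.randn(2, 28, 28, 96)
    assert m(x).shape == x.shape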
class FocalModulationBlock(nn.Module):
""" Focal Modulation Block.
Args:
dim (int): Number of input channels.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float, optional): Dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
focal_level (int): number of focal levels
focal_window (int): focal kernel size at level 1
"""
def __init__(
self,
dim: int,
mlp_ratio: float = 4.0,
focal_level: int = 2,
focal_window: int = 9,
dropout: float = 0.0,
stochastic_depth_prob: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
use_postln: bool = False,
use_postln_in_modulation: bool = False,
normalize_modulator: bool = False,
use_layerscale: bool = False,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.modulation = FocalModulation(
dim,
focal_window=focal_window,
focal_level=focal_level,
proj_drop=dropout,
use_postln_in_modulation=use_postln_in_modulation,
normalize_modulator=normalize_modulator,
)
self.drop_path = StochasticDepth(stochastic_depth_prob, "row")
self.norm2 = norm_layer(dim)
self.mlp = Mlp(
in_features=dim, hidden_features=int(dim * mlp_ratio), activation_layer=nn.GELU, drop=dropout
)
if use_layerscale:
self.gamma_1 = nn.Parameter(torch.full((dim,), 1e-4))
self.gamma_2 = nn.Parameter(torch.full((dim,), 1e-4))
else:
self.gamma_1 = self.gamma_2 = 1.0
self.forward = self._forward_post if use_postln else self._forward_pre
def _forward_post(self, x):
x = x + self.drop_path(self.gamma_1 * self.norm1(self.modulation(x)))
x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
return x
def _forward_pre(self, x):
x = x + self.drop_path(self.gamma_1 * self.modulation(self.norm1(x)))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
Args:
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
use_conv_embed (bool): Whether use overlapped convolution for patch embedding. Default: False
is_stem (bool): Is the stem block or not.
"""
def __init__(
self,
in_channels: int,
hidden_channels: int,
patch_size: List[int] = (4, 4),
norm_layer: nn.Module = nn.LayerNorm,
use_conv_embed: bool = False,
is_stem: bool = False,
):
super().__init__()
self.patch_size = patch_size
if use_conv_embed:
# if we choose to use conv embedding, then we treat the stem and non-stem differently
if is_stem:
self.proj = nn.Conv2d(in_channels, hidden_channels, kernel_size=7, stride=4, padding=2)
else:
self.proj = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, stride=2, padding=1)
else:
self.proj = nn.Conv2d(in_channels, hidden_channels, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(hidden_channels)
def forward(self, input):
"""Forward function.
Args:
input (Tensor[N, H, W, C]): The input tensor of 4 dimensions.
"""
H, W, _ = input.shape[-3:]
# pad feature maps to multiples of patch size
pad_r = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
pad_b = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
input = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
# perform patch embed
input = input.permute(0, 3, 1, 2)
input = self.proj(input)
input = input.permute(0, 2, 3, 1)
input = self.norm(input)
return input
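# Example: a minimal shape-check sketch (sizes are hypothetical); PatchEmbed pads
# channel-last input to a multiple of the patch size, then downsamples by the patch size.
def _patch_embed_demo():
    pe = PatchEmbed(in_channels=3, hidden_channels=96, patch_size=(4, 4))
    x = torch.randn(1, 30, 30, 3)        # H and W are not multiples of 4
    assert pe(x).shape == (1, 8, 8, 96)  # padded to 32x32, then reduced by a factor of 4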
class FocalNet(nn.Module):
"""Implement paper `Focal Modulation Networks <https://arxiv.org/pdf/2203.11926.pdf>`_
Args:
patch_size (int | tuple(int)): Patch size. Default: 4.
embed_dim (int): Number of linear projection output channels. Default: 96.
depths (tuple[int]): Depths of each Swin Transformer stage.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
dropout (float): Dropout rate.
stochastic_depth_prob (float): Stochastic depth rate. Default: 0.2.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
focal_levels (Sequence[int]): Number of focal levels at four stages
focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages
out_indices (Sequence[int]): Output from which stages.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
use_conv_embed (bool): Whether use overlapped convolution for patch embedding
"""
def __init__(
self,
patch_size: List[int],
embed_dim: int,
depths: List[int],
mlp_ratio: float = 4.0,
dropout: float = 0.0,
stochastic_depth_prob: float = 0.3, # 0.3 or 0.4 works better for large+ models
norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-5),
focal_levels=[3, 3, 3, 3],
focal_windows=[3, 3, 3, 3],
use_conv_embed=False,
use_postln=False,
use_postln_in_modulation=False,
use_layerscale=False,
normalize_modulator=False,
):
super().__init__()
self.num_layers = len(depths)
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
in_channels=3,
hidden_channels=embed_dim,
patch_size=patch_size,
norm_layer=norm_layer,
use_conv_embed=use_conv_embed,
is_stem=True
)
self.pos_drop = nn.Dropout(p=dropout)
self.layers = nn.ModuleList()
total_stage_blocks = sum(depths)
stage_block_id = 0
# build FocalNet blocks
for i_stage in range(len(depths)):
blocks: List[nn.Module] = []
dim = embed_dim * 2**i_stage
for _ in range(depths[i_stage]):
# adjust stochastic depth probability based on the depth of the stage block
sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1)
blocks.append(
FocalModulationBlock(
dim,
mlp_ratio=mlp_ratio,
focal_level=focal_levels[i_stage],
focal_window=focal_windows[i_stage],
dropout=dropout,
stochastic_depth_prob=sd_prob,
norm_layer=norm_layer,
use_postln=use_postln,
use_postln_in_modulation=use_postln_in_modulation,
normalize_modulator=normalize_modulator,
use_layerscale=use_layerscale,
)
)
stage_block_id += 1
stage = OrderedDict(blocks=nn.Sequential(*blocks))
# add patch merging layer
if i_stage < (len(depths) - 1):
stage["downsample"] = PatchEmbed(dim, int(2 * dim), (2, 2), norm_layer, use_conv_embed)
self.layers.append(nn.Sequential(stage))
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
# (b, c, h, w) -> (b, h, w, c)
x = x.permute(0, 2, 3, 1)
x = self.pos_drop(self.patch_embed(x))
for i in range(self.num_layers):
x = self.layers[i](x)
return x
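# Example: a minimal sketch with random initialization and an arbitrary 224x224 input; a
# tiny FocalNet reduces the image to stride-32, channel-last features of the last stage.
def _focalnet_demo():
    net = FocalNet(
        patch_size=(4, 4),
        embed_dim=96,
        depths=(2, 2, 6, 2),
        focal_levels=(2, 2, 2, 2),
        focal_windows=(3, 3, 3, 3),
    )
    y = net(torch.randn(1, 3, 224, 224))
    assert y.shape == (1, 7, 7, 768)  # 224 / 32 = 7 and 96 * 2**3 = 768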
class PostProcess(nn.Module):
def __init__(
self,
in_channels: int,
return_indices: Tuple[int] = (0, 1, 2, 3),
norm_layer: nn.Module = nn.LayerNorm,
):
super().__init__()
self.return_indices = return_indices
for channel, idx in zip(in_channels, return_indices):
self.add_module(f"norm{idx}", norm_layer(channel))
def forward(self, multi_level_feats: Dict[str, torch.Tensor]):
for idx, (key, value) in zip(self.return_indices, multi_level_feats.items()):
feat = getattr(self, f"norm{idx}")(value).permute(0, 3, 1, 2).contiguous()
multi_level_feats[key] = feat
return multi_level_feats
class FocalNetBackbone(BaseBackbone):
model_weights = {
# The following weights are from the original repository
"focalnet_tiny_srf":
"https://github.com/microsoft/FocalNet/releases/download/v1.0.0/focalnet_tiny_srf.pth",
"focalnet_tiny_lrf":
"https://github.com/microsoft/FocalNet/releases/download/v1.0.0/focalnet_tiny_lrf.pth",
"focalnet_small_srf":
"https://github.com/microsoft/FocalNet/releases/download/v1.0.0/focalnet_small_srf.pth",
"focalnet_small_lrf":
"https://github.com/microsoft/FocalNet/releases/download/v1.0.0/focalnet_small_lrf.pth",
"focalnet_base_srf":
"https://github.com/microsoft/FocalNet/releases/download/v1.0.0/focalnet_base_srf.pth",
"focalnet_base_lrf":
"https://github.com/microsoft/FocalNet/releases/download/v1.0.0/focalnet_base_lrf.pth",
"focalnet_large_lrf_384":
"https://github.com/microsoft/FocalNet/releases/download/v1.0.0/focalnet_large_lrf_384.pth",
"focalnet_large_lrf_384_fl4":
"https://github.com/microsoft/FocalNet/releases/download/v1.0.0/focalnet_large_lrf_384_fl4.pth",
"focalnet_xlarge_lrf_384":
"https://github.com/microsoft/FocalNet/releases/download/v1.0.0/focalnet_xlarge_lrf_384.pth",
"focalnet_xlarge_lrf_384_fl4":
"https://github.com/microsoft/FocalNet/releases/download/v1.0.0/focalnet_xlarge_lrf_384_fl4.pth",
# The following weights are from huggingface
"focalnet_large_fl4_dino_o365":
"https://huggingface.co/microsoft/focalnet-large-fl4-dino-o365/resolve/main/focalnet_large_fl4_pretrained_on_o365.pth",
"focalnet_large_fl4_dino_o365_coco":
"https://huggingface.co/microsoft/focalnet-large-fl4-dino-o365-cocoft/resolve/main/focalnet_large_fl4_o365_finetuned_on_coco.pth",
}
# yapf: disable
model_arch = {
"focalnet_tiny_srf": L(FocalNet)(
embed_dim=96,
patch_size=(4, 4),
depths=(2, 2, 6, 2),
stochastic_depth_prob=0.2,
focal_levels=(2, 2, 2, 2),
focal_windows=(3, 3, 3, 3),
url=model_weights["focalnet_tiny_srf"],
),
"focalnet_tiny_lrf": L(FocalNet)(
embed_dim=96,
patch_size=(4, 4),
depths=(2, 2, 18, 2),
stochastic_depth_prob=0.2,
focal_levels=(3, 3, 3, 3),
focal_windows=(3, 3, 3, 3),
url=model_weights["focalnet_tiny_lrf"],
),
"focalnet_small_srf": L(FocalNet)(
embed_dim=96,
patch_size=(4, 4),
depths=(2, 2, 18, 2),
stochastic_depth_prob=0.3,
focal_levels=(2, 2, 2, 2),
focal_windows=(3, 3, 3, 3),
url=model_weights["focalnet_small_srf"],
),
"focalnet_small_lrf": L(FocalNet)(
embed_dim=96,
patch_size=(4, 4),
depths=(2, 2, 18, 2),
stochastic_depth_prob=0.3,
focal_levels=(3, 3, 3, 3),
focal_windows=(3, 3, 3, 3),
url=model_weights["focalnet_small_lrf"],
),
"focalnet_base_srf": L(FocalNet)(
embed_dim=128,
patch_size=(4, 4),
depths=(2, 2, 18, 2),
stochastic_depth_prob=0.5,
focal_levels=(2, 2, 2, 2),
focal_windows=(3, 3, 3, 3),
url=model_weights["focalnet_base_srf"],
),
"focalnet_base_lrf": L(FocalNet)(
embed_dim=128,
patch_size=(4, 4),
depths=(2, 2, 18, 2),
stochastic_depth_prob=0.5,
focal_levels=(3, 3, 3, 3),
focal_windows=(3, 3, 3, 3),
url=model_weights["focalnet_base_lrf"],
),
"focalnet_large_lrf": L(FocalNet)(
embed_dim=192,
patch_size=(4, 4),
depths=(2, 2, 18, 2),
stochastic_depth_prob=0.5,
focal_levels=(3, 3, 3, 3),
focal_windows=(5, 5, 5, 5),
use_conv_embed=True,
use_postln=True,
use_postln_in_modulation=False,
use_layerscale=True,
normalize_modulator=False,
url=model_weights["focalnet_large_lrf_384"],
),
"focalnet_large_lrf_fl4": L(FocalNet)(
embed_dim=192,
patch_size=(4, 4),
depths=(2, 2, 18, 2),
stochastic_depth_prob=0.5,
focal_levels=(4, 4, 4, 4),
focal_windows=(3, 3, 3, 3),
use_conv_embed=True,
use_postln=True,
use_postln_in_modulation=False,
use_layerscale=True,
normalize_modulator=True,
url=model_weights["focalnet_large_lrf_384_fl4"],
),
"focalnet_xlarge_lrf": L(FocalNet)(
embed_dim=256,
patch_size=(4, 4),
depths=(2, 2, 18, 2),
stochastic_depth_prob=0.5,
focal_levels=(3, 3, 3, 3),
focal_windows=(5, 5, 5, 5),
use_conv_embed=True,
use_postln=True,
use_postln_in_modulation=False,
use_layerscale=True,
normalize_modulator=False,
url=model_weights["focalnet_xlarge_lrf_384"],
),
"focalnet_xlarge_lrf_fl4": L(FocalNet)(
embed_dim=256,
patch_size=(4, 4),
depths=(2, 2, 18, 2),
stochastic_depth_prob=0.5,
focal_levels=(4, 4, 4, 4),
focal_windows=(3, 3, 3, 3),
use_conv_embed=True,
use_postln=True,
use_postln_in_modulation=False,
use_layerscale=True,
normalize_modulator=True,
url=model_weights["focalnet_xlarge_lrf_384_fl4"],
),
"focalnet_huge_fl3": L(FocalNet)(
embed_dim=352,
patch_size=(4, 4),
depths=(2, 2, 18, 2),
stochastic_depth_prob=0.5,
focal_levels=(3, 3, 3, 3),
focal_windows=(3, 3, 3, 3),
use_conv_embed=True,
use_postln=True,
use_layerscale=True,
use_postln_in_modulation=True,
normalize_modulator=False,
),
"focalnet_huge_fl4": L(FocalNet)(
embed_dim=352,
patch_size=(4, 4),
depths=(2, 2, 18, 2),
stochastic_depth_prob=0.5,
focal_levels=(4, 4, 4, 4),
focal_windows=(3, 3, 3, 3),
use_conv_embed=True,
use_postln=True,
use_postln_in_modulation=True,
use_layerscale=True,
normalize_modulator=False,
),
}
# yapf: enable
def __new__(
self,
arch: str,
weights: Dict = None,
return_indices: Tuple[int] = (0, 1, 2, 3),
freeze_indices: Tuple = (),
**kwargs,
):
# get parameters and instantiate backbone
model_config = self.get_instantiate_config(self, FocalNet, arch, kwargs)
default_weight = model_config.pop("url", None)
focalnet = instantiate(model_config)
# load state dict
weights = load_checkpoint(default_weight if weights is None else weights)
if isinstance(weights, Dict):
weights = weights["model"] if "model" in weights else weights
self.load_state_dict(focalnet, weights)
# freeze stages
self._freeze_stages(self, focalnet, freeze_indices)
# create feature extractor
return_layers = [f"layers.{idx}.blocks" for idx in return_indices]
focalnet = create_feature_extractor(focalnet, return_layers)
focalnet.num_channels = [model_config.embed_dim * 2**idx for idx in return_indices]
# add post_process for feature extractor
post_process = PostProcess(focalnet.num_channels, return_indices, model_config.norm_layer)
backbone = nn.Sequential(focalnet, post_process)
backbone.num_channels = focalnet.num_channels
return backbone
def _freeze_stages(self, model: nn.Module, freeze_indices: Tuple[int]):
# freeze patch embed
if len(freeze_indices) > 0:
self.freeze_module(model.patch_embed)
# freeze layers
for idx in freeze_indices:
self.freeze_module(model.layers[idx])
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
import torch
from torch import Tensor, nn
from torchvision.models.feature_extraction import create_feature_extractor
from models.bricks.misc import FrozenBatchNorm2d
from models.backbones.base_backbone import BaseBackbone
from models.bricks.deform_conv2d_pack import DeformConv2dPack
from util.lazy_load import LazyCall as L
from util.lazy_load import instantiate
from util.utils import load_checkpoint
def conv3x3(
in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1
) -> nn.Conv2d:
"""3x3 convolution with padding"""
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=dilation,
groups=groups,
bias=False,
dilation=dilation,
)
def conv3x3_dcn(
in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1
) -> DeformConv2dPack:
"""3x3 deformable convolution with padding"""
return DeformConv2dPack(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=dilation,
groups=groups,
bias=False,
dilation=dilation,
)
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion: int = 1
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None,
with_dcn: bool = False,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError("BasicBlock only supports groups=1 and base_width=64")
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
if with_dcn:
self.conv2 = conv3x3_dcn(planes, planes)
else:
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
expansion: int = 4
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None,
with_dcn: bool = False,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.0)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
if with_dcn:
self.conv2 = conv3x3_dcn(width, width, stride, groups, dilation)
else:
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(
self,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
num_classes: int = 1000,
zero_init_residual: bool = False,
groups: int = 1,
width_per_group: int = 64,
replace_stride_with_dilation: Optional[List[bool]] = None,
stage_with_dcn: Optional[List[bool]] = None, # we only add an extra parameter
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()
if stage_with_dcn is None:
stage_with_dcn = [False] * 4
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError(
"replace_stride_with_dilation should be None "
f"or a 3-element tuple, got {replace_stride_with_dilation}"
)
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], with_dcn=stage_with_dcn[0])
self.layer2 = self._make_layer(
block,
128,
layers[1],
stride=2,
dilate=replace_stride_with_dilation[0],
with_dcn=stage_with_dcn[1],
)
self.layer3 = self._make_layer(
block,
256,
layers[2],
stride=2,
dilate=replace_stride_with_dilation[1],
with_dcn=stage_with_dcn[2],
)
self.layer4 = self._make_layer(
block,
512,
layers[3],
stride=2,
dilate=replace_stride_with_dilation[2],
with_dcn=stage_with_dcn[3],
)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck) and m.bn3.weight is not None:
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
def _make_layer(
self,
block: Type[Union[BasicBlock, Bottleneck]],
planes: int,
blocks: int,
stride: int = 1,
dilate: bool = False,
with_dcn: bool = False,
) -> nn.Sequential:
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(
block(
self.inplanes,
planes,
stride,
downsample,
self.groups,
self.base_width,
previous_dilation,
norm_layer,
with_dcn,
)
)
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
block(
self.inplanes,
planes,
groups=self.groups,
base_width=self.base_width,
dilation=self.dilation,
norm_layer=norm_layer,
with_dcn=with_dcn,
)
)
return nn.Sequential(*layers)
def _forward_impl(self, x: Tensor) -> Tensor:
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)
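# Example: a minimal sketch with random initialization; the last stage uses dilation
# instead of striding (a common trick for dense prediction), and stage_with_dcn keeps its
# default so no deformable convolutions are constructed.
def _resnet_demo():
    net = ResNet(block=Bottleneck, layers=(3, 4, 6, 3), replace_stride_with_dilation=(False, False, True))
    logits = net(torch.randn(1, 3, 224, 224))
    assert logits.shape == (1, 1000)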
class ResNetBackbone(BaseBackbone):
# yapf: disable
model_weights = {
# The following weights are from torchvision
"resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
"resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
"resnet50_v1": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
"resnet50_v2": "https://download.pytorch.org/models/resnet50-11ad3fa6.pth",
"resnet101_v1": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
"resnet101_v2": "https://download.pytorch.org/models/resnet101-cd907fc2.pth",
"resnet152_v1": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
"resnet152_v2": "https://download.pytorch.org/models/resnet152-f82ba261.pth",
"resnext50_32x4d_v1": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
"resnext50_32x4d_v2": "https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
"resnext101_32x8d_v1": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
"resnext101_32x8d_v2": "https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
"resnext101_64x4d": "https://download.pytorch.org/models/resnext101_64x4d-173b62eb.pth",
"wide_resnet50_2_v1": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
"wide_resnet50_2_v2": "https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
"wide_resnet101_2_v1": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
"wide_resnet101_2_v2": "https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
# The following weights are transformed from mmpretrain
"resnext101_32x4d":
"https://github.com/xiuqhou/pretrained_weights/releases/download/v1.0.1-beta/resnext101_32x4d-e0fa3dd5.pth",
}
model_arch = {
"resnet18": L(ResNet)(block=BasicBlock, layers=(2, 2, 2, 2), url=model_weights["resnet18"]),
"resnet34": L(ResNet)(block=BasicBlock, layers=(3, 4, 6, 3), url=model_weights["resnet34"]),
"resnet50": L(ResNet)(block=Bottleneck, layers=(3, 4, 6, 3), url=model_weights["resnet50_v2"]),
"resnet101": L(ResNet)(block=Bottleneck, layers=(3, 4, 23, 3), url=model_weights["resnet101_v2"]),
"resnet152": L(ResNet)(block=Bottleneck, layers=(3, 8, 36, 3), url=model_weights["resnet152_v2"]),
"resnext50_32x4d": L(ResNet)(
block=Bottleneck,
layers=(3, 4, 6, 3),
groups=32,
width_per_group=4,
url=model_weights["resnext50_32x4d_v2"],
),
"resnext101_32x4d": L(ResNet)(
block=Bottleneck,
layers=(3, 4, 23, 3),
groups=32,
width_per_group=4,
url=model_weights["resnext101_32x4d"],
),
"resnext101_32x8d": L(ResNet)(
block=Bottleneck,
layers=(3, 4, 23, 3),
groups=32,
width_per_group=8,
url=model_weights["resnext101_32x8d_v2"],
),
"resnext101_64x4d": L(ResNet)(
block=Bottleneck,
layers=(3, 4, 23, 3),
groups=64,
width_per_group=4,
url=model_weights["resnext101_64x4d"],
),
"wide_resnet50_2": L(ResNet)(
block=Bottleneck,
layers=(3, 4, 6, 3),
width_per_group=64 * 2,
url=model_weights["wide_resnet50_2_v2"],
),
"wide_resnet101_2": L(ResNet)(
block=Bottleneck,
layers=(3, 4, 23, 3),
width_per_group=64 * 2,
url=model_weights["wide_resnet101_2_v2"],
),
}
# yapf: enable
def __new__(
self,
arch: str,
weights: Dict = None,
return_indices: Tuple[int] = (0, 1, 2, 3),
freeze_indices: Tuple = (),
**kwargs,
):
# get parameters and instantiate backbone
model_config = self.get_instantiate_config(self, ResNet, arch, kwargs)
default_weight = model_config.pop("url", None)
resnet = instantiate(model_config)
# load state dict
weights = load_checkpoint(default_weight if weights is None else weights)
if isinstance(weights, Dict):
weights = weights["model"] if "model" in weights else weights
self.load_state_dict(resnet, weights)
# freeze stages
self._freeze_stages(self, resnet, freeze_indices)
# create feature extractor
return_layers = [f"layer{idx + 1}" for idx in return_indices]
resnet = create_feature_extractor(
resnet, return_layers, tracer_kwargs={"leaf_modules": [FrozenBatchNorm2d]}
)
resnet.num_channels = [64 * model_config.block.expansion * 2**idx for idx in return_indices]
return resnet
def _freeze_stages(self, model: nn.Module, freeze_indices: Tuple[int]):
# freeze stem
if len(freeze_indices) > 0:
self.freeze_module(model.conv1)
self.freeze_module(model.bn1)
# freeze layers
for idx in freeze_indices:
self.freeze_module(model.get_submodule(f"layer{idx+1}"))
import math
import os
import sys
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from omegaconf import OmegaConf
from torch import Tensor, nn
from torch.nn import functional as F
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.ops import MLP, Permute, StochasticDepth
from models.backbones.base_backbone import BaseBackbone
from util.lazy_load import LazyCall as L
from util.lazy_load import instantiate
from util.utils import load_checkpoint
def _patch_merging_pad(x: torch.Tensor) -> torch.Tensor:
H, W, _ = x.shape[-3:]
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C
x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C
return x
torch.fx.wrap("_patch_merging_pad")
def _get_relative_position_bias(
relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int]
) -> torch.Tensor:
N = window_size[0] * window_size[1]
relative_position_bias = relative_position_bias_table[relative_position_index] # type: ignore[index]
relative_position_bias = relative_position_bias.view(N, N, -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
return relative_position_bias
torch.fx.wrap("_get_relative_position_bias")
class PatchMerging(nn.Module):
"""Patch Merging Layer.
Args:
dim (int): Number of input channels.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
"""
def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x: Tensor):
"""
Args:
x (Tensor): input tensor with expected layout of [..., H, W, C]
Returns:
Tensor with layout of [..., H/2, W/2, 2*C]
"""
x = _patch_merging_pad(x)
x = self.norm(x)
x = self.reduction(x) # ... H/2 W/2 2*C
return x
class PatchMergingV2(nn.Module):
"""Patch Merging Layer for Swin Transformer V2.
Args:
dim (int): Number of input channels.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
"""
def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(2 * dim) # difference
def forward(self, x: Tensor):
"""
Args:
x (Tensor): input tensor with expected layout of [..., H, W, C]
Returns:
Tensor with layout of [..., H/2, W/2, 2*C]
"""
x = _patch_merging_pad(x)
x = self.reduction(x) # ... H/2 W/2 2*C
x = self.norm(x)
return x
def shifted_window_attention(
input: Tensor,
qkv_weight: Tensor,
proj_weight: Tensor,
relative_position_bias: Tensor,
window_size: List[int],
num_heads: int,
shift_size: List[int],
attention_dropout: float = 0.0,
dropout: float = 0.0,
qkv_bias: Optional[Tensor] = None,
proj_bias: Optional[Tensor] = None,
logit_scale: Optional[torch.Tensor] = None,
training: bool = True,
) -> Tensor:
"""
Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
input (Tensor[N, H, W, C]): The input tensor of 4 dimensions.
qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value.
proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection.
relative_position_bias (Tensor): The learned relative position bias added to attention.
window_size (List[int]): Window size.
num_heads (int): Number of attention heads.
shift_size (List[int]): Shift size for shifted window attention.
attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
dropout (float): Dropout ratio of output. Default: 0.0.
qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None.
training (bool, optional): Training flag used by the dropout parameters. Default: True.
Returns:
Tensor[N, H, W, C]: The output tensor after shifted window attention.
"""
B, H, W, C = input.shape
# pad feature maps to multiples of window size
pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
_, pad_H, pad_W, _ = x.shape
shift_size = shift_size.copy()
# If window size is larger than feature size, there is no need to shift window
if window_size[0] >= pad_H:
shift_size[0] = 0
if window_size[1] >= pad_W:
shift_size[1] = 0
# cyclic shift
if sum(shift_size) > 0:
x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
# partition windows
num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4,
5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C
# multi-head attention
if logit_scale is not None and qkv_bias is not None:
qkv_bias = qkv_bias.clone()
length = qkv_bias.numel() // 3
qkv_bias[length:2 * length].zero_()
qkv = F.linear(x, qkv_weight, qkv_bias)
qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
if logit_scale is not None:
# cosine attention
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
logit_scale = torch.clamp(logit_scale, max=math.log(100.0)).exp()
attn = attn * logit_scale
else:
q = q * (C // num_heads)**-0.5
attn = q.matmul(k.transpose(-2, -1))
# add relative position bias
attn = attn + relative_position_bias
if sum(shift_size) > 0:
# generate attention mask
attn_mask = x.new_zeros((pad_H, pad_W))
h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None))
w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None))
count = 0
for h in h_slices:
for w in w_slices:
attn_mask[h[0]:h[1], w[0]:w[1]] = count
count += 1
attn_mask = attn_mask.view(
pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1]
)
attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1])
attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0,
float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, num_heads, x.size(1), x.size(1))
attn = F.softmax(attn, dim=-1)
attn = F.dropout(attn, p=attention_dropout, training=training)
x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
x = F.linear(x, proj_weight, proj_bias)
x = F.dropout(x, p=dropout, training=training)
# reverse windows
x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C)
# reverse cyclic shift
if sum(shift_size) > 0:
x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
# unpad features
x = x[:, :H, :W, :].contiguous()
return x
torch.fx.wrap("shifted_window_attention")
class ShiftedWindowAttention(nn.Module):
"""
See :func:`shifted_window_attention`.
"""
def __init__(
self,
dim: int,
window_size: List[int],
shift_size: List[int],
num_heads: int,
qkv_bias: bool = True,
proj_bias: bool = True,
attention_dropout: float = 0.0,
dropout: float = 0.0,
):
super().__init__()
if len(window_size) != 2 or len(shift_size) != 2:
raise ValueError("window_size and shift_size must be of length 2")
self.window_size = window_size
self.shift_size = shift_size
self.num_heads = num_heads
self.attention_dropout = attention_dropout
self.dropout = dropout
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.define_relative_position_bias_table()
self.define_relative_position_index()
def define_relative_position_bias_table(self):
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
def define_relative_position_index(self):
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1).flatten() # Wh*Ww*Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
def get_relative_position_bias(self) -> torch.Tensor:
return _get_relative_position_bias(
self.relative_position_bias_table,
self.relative_position_index,
self.window_size # type: ignore[arg-type]
)
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x (Tensor): Tensor with layout of [B, H, W, C]
Returns:
Tensor with same layout as input, i.e. [B, H, W, C]
"""
relative_position_bias = self.get_relative_position_bias()
return shifted_window_attention(
x,
self.qkv.weight,
self.proj.weight,
relative_position_bias,
self.window_size,
self.num_heads,
shift_size=self.shift_size,
attention_dropout=self.attention_dropout,
dropout=self.dropout,
qkv_bias=self.qkv.bias,
proj_bias=self.proj.bias,
training=self.training,
)
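# --- Editor's illustrative sketch (not part of the original file). ---
# A minimal usage of ShiftedWindowAttention on a [B, H, W, C] feature map; the
# layout is preserved, so the block drops into a residual branch unchanged.
def _example_shifted_window_attention():
    attn = ShiftedWindowAttention(dim=96, window_size=[7, 7], shift_size=[3, 3], num_heads=3)
    out = attn(torch.randn(2, 56, 56, 96))
    assert out.shape == (2, 56, 56, 96)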
class ShiftedWindowAttentionV2(ShiftedWindowAttention):
"""
See :func:`shifted_window_attention_v2`.
"""
def __init__(
self,
dim: int,
window_size: List[int],
shift_size: List[int],
num_heads: int,
qkv_bias: bool = True,
proj_bias: bool = True,
attention_dropout: float = 0.0,
dropout: float = 0.0,
):
super().__init__(
dim,
window_size,
shift_size,
num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attention_dropout=attention_dropout,
dropout=dropout,
)
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
# mlp to generate continuous relative position bias
self.cpb_mlp = nn.Sequential(
nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
)
if qkv_bias:
length = self.qkv.bias.numel() // 3
self.qkv.bias[length:2 * length].data.zero_()
def define_relative_position_bias_table(self):
# get relative_coords_table
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
relative_coords_table = torch.stack(
torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij")
)
relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(
0
) # 1, 2*Wh-1, 2*Ww-1, 2
relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
relative_coords_table *= 8 # normalize to -8, 8
relative_coords_table = (
torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / 3.0
)
self.register_buffer("relative_coords_table", relative_coords_table)
def get_relative_position_bias(self) -> torch.Tensor:
relative_position_bias = _get_relative_position_bias(
self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads),
self.relative_position_index, # type: ignore[arg-type]
self.window_size,
)
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
return relative_position_bias
def forward(self, x: Tensor):
"""
Args:
x (Tensor): Tensor with layout of [B, H, W, C]
Returns:
Tensor with same layout as input, i.e. [B, H, W, C]
"""
relative_position_bias = self.get_relative_position_bias()
return shifted_window_attention(
x,
self.qkv.weight,
self.proj.weight,
relative_position_bias,
self.window_size,
self.num_heads,
shift_size=self.shift_size,
attention_dropout=self.attention_dropout,
dropout=self.dropout,
qkv_bias=self.qkv.bias,
proj_bias=self.proj.bias,
logit_scale=self.logit_scale,
training=self.training,
)
class SwinTransformerBlock(nn.Module):
"""
Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (List[int]): Window size.
shift_size (List[int]): Shift size for shifted window attention.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
dropout (float): Dropout rate. Default: 0.0.
attention_dropout (float): Attention dropout rate. Default: 0.0.
stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention
"""
def __init__(
self,
dim: int,
num_heads: int,
window_size: List[int],
shift_size: List[int],
mlp_ratio: float = 4.0,
dropout: float = 0.0,
attention_dropout: float = 0.0,
stochastic_depth_prob: float = 0.0,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = attn_layer(
dim,
window_size,
shift_size,
num_heads,
attention_dropout=attention_dropout,
dropout=dropout,
)
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
self.norm2 = norm_layer(dim)
self.mlp = MLP(
dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout
)
for m in self.mlp.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.normal_(m.bias, std=1e-6)
def forward(self, x: Tensor):
x = x + self.stochastic_depth(self.attn(self.norm1(x)))
x = x + self.stochastic_depth(self.mlp(self.norm2(x)))
return x
class SwinTransformerBlockV2(SwinTransformerBlock):
"""
Swin Transformer V2 Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (List[int]): Window size.
shift_size (List[int]): Shift size for shifted window attention.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
dropout (float): Dropout rate. Default: 0.0.
attention_dropout (float): Attention dropout rate. Default: 0.0.
stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttentionV2.
"""
def __init__(
self,
dim: int,
num_heads: int,
window_size: List[int],
shift_size: List[int],
mlp_ratio: float = 4.0,
dropout: float = 0.0,
attention_dropout: float = 0.0,
stochastic_depth_prob: float = 0.0,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_layer: Callable[..., nn.Module] = ShiftedWindowAttentionV2,
):
super().__init__(
dim,
num_heads,
window_size,
shift_size,
mlp_ratio=mlp_ratio,
dropout=dropout,
attention_dropout=attention_dropout,
stochastic_depth_prob=stochastic_depth_prob,
norm_layer=norm_layer,
attn_layer=attn_layer,
)
def forward(self, x: Tensor):
# Here is the difference: in V2 we apply the norm after the attention,
# whereas in V1 the norm is applied before the attention.
x = x + self.stochastic_depth(self.norm1(self.attn(x)))
x = x + self.stochastic_depth(self.norm2(self.mlp(x)))
return x
class SwinTransformer(nn.Module):
"""
Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
Shifted Windows" <https://arxiv.org/abs/2103.14030>`_ paper.
Args:
patch_size (List[int]): Patch size.
embed_dim (int): Patch embedding dimension.
depths (List(int)): Depth of each Swin Transformer layer.
num_heads (List(int)): Number of attention heads in different layers.
window_size (List[int]): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
dropout (float): Dropout rate. Default: 0.0.
attention_dropout (float): Attention dropout rate. Default: 0.0.
stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1.
num_classes (int): Number of classes for classification head. Default: 1000.
block (nn.Module, optional): SwinTransformer Block. Default: None.
norm_layer (nn.Module, optional): Normalization layer. Default: None.
downsample_layer (nn.Module): Downsample layer (patch merging). Default: PatchMerging.
"""
def __init__(
self,
patch_size: List[int],
embed_dim: int,
depths: List[int],
num_heads: List[int],
window_size: List[int],
mlp_ratio: float = 4.0,
dropout: float = 0.0,
attention_dropout: float = 0.0,
stochastic_depth_prob: float = 0.1,
num_classes: int = 1000,
norm_layer: Optional[Callable[..., nn.Module]] = None,
block: Optional[Callable[..., nn.Module]] = None,
downsample_layer: Callable[..., nn.Module] = PatchMerging,
):
super().__init__()
self.num_classes = num_classes
if block is None:
block = SwinTransformerBlock
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-5)
layers: List[nn.Module] = []
# split image into non-overlapping patches
layers.append(
nn.Sequential(
nn.Conv2d(
3,
embed_dim,
kernel_size=(patch_size[0], patch_size[1]),
stride=(patch_size[0], patch_size[1])
),
Permute([0, 2, 3, 1]),
norm_layer(embed_dim),
)
)
total_stage_blocks = sum(depths)
stage_block_id = 0
# build SwinTransformer blocks
for i_stage in range(len(depths)):
stage: List[nn.Module] = []
dim = embed_dim * 2**i_stage
for i_layer in range(depths[i_stage]):
# adjust stochastic depth probability based on the depth of the stage block
sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1)
stage.append(
block(
dim,
num_heads[i_stage],
window_size=window_size,
shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size],
mlp_ratio=mlp_ratio,
dropout=dropout,
attention_dropout=attention_dropout,
stochastic_depth_prob=sd_prob,
norm_layer=norm_layer,
)
)
stage_block_id += 1
layers.append(nn.Sequential(*stage))
# add patch merging layer
if i_stage < (len(depths) - 1):
layers.append(downsample_layer(dim, norm_layer))
self.features = nn.Sequential(*layers)
num_features = embed_dim * 2**(len(depths) - 1)
self.norm = norm_layer(num_features)
self.permute = Permute([0, 3, 1, 2]) # B H W C -> B C H W
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.flatten = nn.Flatten(1)
self.head = nn.Linear(num_features, num_classes)
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
def forward(self, x):
x = self.features(x)
x = self.norm(x)
x = self.permute(x)
x = self.avgpool(x)
x = self.flatten(x)
x = self.head(x)
return x
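# --- Editor's illustrative sketch (not part of the original file). ---
# Building a swin_t-sized classifier by hand; the hyper-parameters mirror the
# "swin_t" entry of model_arch defined below.
def _example_swin_transformer_classifier():
    model = SwinTransformer(
        patch_size=[4, 4],
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[7, 7],
        stochastic_depth_prob=0.2,
    )
    logits = model(torch.randn(1, 3, 224, 224))
    assert logits.shape == (1, 1000)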
class PostProcess(nn.Module):
def forward(self, multi_level_feats: Dict[str, Tensor]):
return {k: v.permute(0, 3, 1, 2) for k, v in multi_level_feats.items()}
class SwinTransformerBackbone(BaseBackbone):
# yapf: disable
model_weights = {
# The following weights are from torchvision
"swin_t": "https://download.pytorch.org/models/swin_t-704ceda3.pth",
"swin_s": "https://download.pytorch.org/models/swin_s-5e29d889.pth",
"swin_b": "https://download.pytorch.org/models/swin_b-68c6b09e.pth",
"swin_v2_t": "https://download.pytorch.org/models/swin_v2_t-b137f0e2.pth",
"swin_v2_b": "https://download.pytorch.org/models/swin_v2_b-781e5279.pth",
# The following weights are converted from the original repo
# Swin_T
"swin_t_in1k":
"https://github.com/xiuqhou/pretrained_weights/releases/download/v1.0.2-beta/swin_tiny_patch4_window7_224.pth",
"swin_t_in22k":
"https://github.com/xiuqhou/pretrained_weights/releases/download/v1.0.2-beta/swin_tiny_patch4_window7_224_22k.pth",
"swin_t_in22kto1k":
"https://github.com/xiuqhou/pretrained_weights/releases/download/v1.0.2-beta/swin_tiny_patch4_window7_224_22kto1k.pth",
# Swin_S
"swin_s_in1k":
"https://github.com/xiuqhou/pretrained_weights/releases/download/v1.0.2-beta/swin_small_patch4_window7_224.pth",
"swin_s_in22k":
"https://github.com/xiuqhou/pretrained_weights/releases/download/v1.0.2-beta/swin_small_patch4_window7_224_22k.pth",
"swin_s_in22kto1k":
"https://github.com/xiuqhou/pretrained_weights/releases/download/v1.0.2-beta/swin_small_patch4_window7_224_22kto1k_finetune.pth",
# Swin_B
"swin_b_in1k":
"https://github.com/xiuqhou/pretrained_weights/releases/download/v1.0.2-beta/swin_base_patch4_window7_224.pth",
"swin_b_in22k":
"https://github.com/xiuqhou/pretrained_weights/releases/download/v1.0.2-beta/swin_base_patch4_window7_224_22k.pth",
"swin_b_in22kto1k":
"https://github.com/xiuqhou/pretrained_weights/releases/download/v1.0.2-beta/swin_base_patch4_window7_224_22kto1k.pth",
# Swin_B_384
"swin_b_384_in22k":
"https://github.com/xiuqhou/pretrained_weights/releases/download/v1.0.2-beta/swin_base_patch4_window12_384_22k.pth",
"swin_b_384_in22kto1k":
"https://github.com/xiuqhou/pretrained_weights/releases/download/v1.0.2-beta/swin_base_patch4_window12_384_22kto1k.pth",
# Swin_L
"swin_l_in22k":
"https://github.com/xiuqhou/pretrained_weights/releases/download/v1.0.2-beta/swin_large_patch4_window7_224_22k.pth",
"swin_l_in22kto1k":
"https://github.com/xiuqhou/pretrained_weights/releases/download/v1.0.2-beta/swin_large_patch4_window7_224_22kto1k.pth",
# Swin_L_384
"swin_l_384_in22k":
"https://github.com/xiuqhou/pretrained_weights/releases/download/v1.0.2-beta/swin_large_patch4_window12_384_22k.pth",
"swin_l_384_in22kto1k":
"https://github.com/xiuqhou/pretrained_weights/releases/download/v1.0.2-beta/swin_large_patch4_window12_384_22kto1k.pth",
}
model_arch = {
"swin_t": L(SwinTransformer)(
patch_size=(4, 4),
embed_dim=96,
depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24),
window_size=(7, 7),
stochastic_depth_prob=0.2,
url=model_weights["swin_t"],
),
"swin_s": L(SwinTransformer)(
patch_size=(4, 4),
embed_dim=96,
depths=(2, 2, 18, 2),
num_heads=(3, 6, 12, 24),
window_size=(7, 7),
stochastic_depth_prob=0.3,
url=model_weights["swin_s"],
),
"swin_b": L(SwinTransformer)(
patch_size=(4, 4),
embed_dim=128,
depths=(2, 2, 18, 2),
num_heads=(4, 8, 16, 32),
window_size=(7, 7),
stochastic_depth_prob=0.5,
url=model_weights["swin_b"],
),
"swin_l": L(SwinTransformer)(
patch_size=(4, 4),
embed_dim=192,
depths=(2, 2, 18, 2),
num_heads=(6, 12, 24, 48),
window_size=(7, 7),
stochastic_depth_prob=0.2,
url=model_weights["swin_l_in22k"],
),
"swin_b_384": L(SwinTransformer)(
patch_size=(4, 4),
embed_dim=128,
depths=(2, 2, 18, 2),
num_heads=(4, 8, 16, 32),
window_size=(12, 12),
stochastic_depth_prob=0.2,
url=model_weights["swin_b_384_in22k"],
),
"swin_l_384": L(SwinTransformer)(
patch_size=(4, 4),
embed_dim=192,
depths=(2, 2, 18, 2),
num_heads=(6, 12, 24, 48),
window_size=(12, 12),
stochastic_depth_prob=0.2,
url=model_weights["swin_l_384_in22k"],
),
"swin_v2_t": L(SwinTransformer)(
patch_size=[4, 4],
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=[8, 8],
stochastic_depth_prob=0.2,
block=SwinTransformerBlockV2,
downsample_layer=PatchMergingV2,
url=model_weights["swin_v2_t"],
),
"swin_v2_b": L(SwinTransformer)(
patch_size=[4, 4],
embed_dim=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
window_size=[8, 8],
stochastic_depth_prob=0.5,
block=SwinTransformerBlockV2,
downsample_layer=PatchMergingV2,
url=model_weights["swin_v2_b"],
),
}
# yapf: enable
def __new__(
self,
arch: str,
weights: Union[str, Dict] = None,
return_indices: Tuple[int] = (0, 1, 2, 3),
freeze_indices: Tuple[int] = (),
**kwargs
):
# get parameters and instantiate backbone
model_config = self.get_instantiate_config(self, SwinTransformer, arch, kwargs)
default_weight = model_config.pop("url", None)
# omegaconf automatically converts native containers to MutableMapping,
# which may lead to type-check errors during tracing.
# Convert it back to a python native mapping type.
swin_transformer = instantiate(OmegaConf.to_object(model_config))
# load state dict
weights = load_checkpoint(default_weight if weights is None else weights)
if isinstance(weights, Dict):
weights = weights["model"] if "model" in weights else weights
self.load_state_dict(swin_transformer, weights)
# freeze stages
self._freeze_stages(self, swin_transformer, freeze_indices)
# create feature extractor
return_layers = [f"features.{2 * idx + 1}" for idx in return_indices]
swin_transformer = create_feature_extractor(swin_transformer, return_layers)
swin_transformer.num_channels = [model_config.embed_dim * 2**idx for idx in return_indices]
# add post_process for swin_transformer output
backbone = nn.Sequential(swin_transformer, PostProcess())
backbone.num_channels = swin_transformer.num_channels
return backbone
def _freeze_stages(self, model: nn.Module, freeze_indices: Tuple[int]):
if len(freeze_indices) > 0:
self.freeze_module(model.features[0])
for idx in freeze_indices:
# freeze layers
self.freeze_module(model.features[2 * idx + 1])
# freeze downsample layers
if 2 * idx + 2 < len(model.features):
self.freeze_module(model.features[2 * idx + 2])
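# --- Editor's illustrative sketch (not part of the original file). ---
# Building a multi-scale detection backbone from the registry above. Instantiating
# it downloads the pretrained "swin_t" weights listed in model_weights, so this is
# only a sketch of the intended call pattern, not a self-contained unit test.
def _example_swin_backbone():
    backbone = SwinTransformerBackbone("swin_t", return_indices=(1, 2, 3))
    feats = backbone(torch.randn(1, 3, 224, 224))   # dict of [B, C, H, W] feature maps
    print(backbone.num_channels, {k: tuple(v.shape) for k, v in feats.items()})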
import math
import os
import sys
from collections import OrderedDict
from functools import partial
from math import pi
from typing import Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torchvision.models.convnext import LayerNorm2d
from torchvision.models.vision_transformer import ConvStemConfig, MLPBlock
from torchvision.ops.stochastic_depth import StochasticDepth
from models.backbones.base_backbone import BaseBackbone
from models.bricks.misc import Conv2dNormActivation
from util.lazy_load import LazyCall as L
from util.lazy_load import instantiate
from util.utils import load_checkpoint, load_state_dict
try:
import xformers.ops as xops
HAS_XFORMER = True
except ImportError:
HAS_XFORMER = False
def window_partition(x, window_size):
"""Partition into non-overlapping windows with padding if needed.
:param x: input tokens with [B, H, W, C].
:param window_size: window size.
:return: windows and padded height and width
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(windows, window_size, pad_hw, hw):
"""Window unpartition into original sequences and removing padding.
:param windows: input tokens with [B * num_windows, window_size, window_size, C].
:param window_size: window size.
:param pad_hw: padded height and width (Hp, Wp).
:param hw: original height and width (H, W) before padding.
:return: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
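# --- Editor's illustrative sketch (not part of the original file). ---
# window_partition pads H and W up to a multiple of window_size before splitting
# into windows; window_unpartition reverses the split and crops the padding away,
# so the round trip is lossless.
def _example_window_roundtrip():
    x = torch.randn(1, 30, 30, 64)                         # [B, H, W, C]
    windows, pad_hw = window_partition(x, window_size=16)  # pads 30 -> 32
    y = window_unpartition(windows, 16, pad_hw, (30, 30))
    assert torch.equal(x, y)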
def rotate_half(x):
x = x.view(*x.shape[:-1], x.shape[-1] // 2, 2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return x.view(*x.shape[:-2], x.shape[-2] * x.shape[-1])
class VisionRotaryEmbeddingFast(nn.Module):
def __init__(
self,
dim,
pt_seq_len=16,
ft_seq_len=None,
custom_freqs=None,
freqs_for='lang',
theta=10000,
max_freq=10,
num_freqs=1,
):
super().__init__()
if custom_freqs:
freqs = custom_freqs
elif freqs_for == 'lang':
freqs = 1. / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
elif freqs_for == 'pixel':
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
elif freqs_for == 'constant':
freqs = torch.ones(num_freqs).float()
else:
raise ValueError(f'unknown modality {freqs_for}')
if ft_seq_len is None:
ft_seq_len = pt_seq_len
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
t = t.expand(ft_seq_len, -1)
t = torch.stack([t.T, t], -1)
freqs = t.unsqueeze(-1) * freqs
freqs = freqs.repeat_interleave(2, -1).view(ft_seq_len, ft_seq_len, -1)
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
self.register_buffer("freqs_cos", freqs_cos)
self.register_buffer("freqs_sin", freqs_sin)
def forward(self, t):
return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
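# --- Editor's illustrative sketch (not part of the original file). ---
# The rotary embedding is applied per attention head: with dim = head_dim // 2 the
# cached cos/sin tables have head_dim channels, and the sequence length must equal
# ft_seq_len ** 2 (a square grid of patch tokens).
def _example_vision_rope():
    rope = VisionRotaryEmbeddingFast(dim=32, pt_seq_len=16, ft_seq_len=14)
    q = torch.randn(2, 12, 14 * 14, 64)   # [B, num_heads, N, head_dim]
    assert rope(q).shape == q.shape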
class SwiGLU(nn.Module):
def __init__(
self,
in_channels: int,
hidden_channels: int = None,
out_channels: int = None,
norm_layer: nn.Module = nn.LayerNorm,
activation_layer: nn.Module = nn.SiLU,
bias: bool = True,
dropout: float = 0.0,
):
super().__init__()
out_channels = out_channels or in_channels
hidden_channels = hidden_channels or in_channels
self.w1 = nn.Linear(in_channels, hidden_channels, bias=bias)
self.w2 = nn.Linear(in_channels, hidden_channels, bias=bias)
self.act = activation_layer()
self.ffn_ln = nn.Identity() if norm_layer is None else norm_layer(hidden_channels)
self.w3 = nn.Linear(hidden_channels, out_channels, bias=bias)
self.drop = nn.Dropout(dropout)
def forward(self, x):
x1 = self.w1(x)
x2 = self.w2(x)
hidden = self.act(x1) * x2
x = self.ffn_ln(hidden)
x = self.w3(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=True,
qk_scale=None,
attn_head_dim=None,
rope=None,
xattn=True,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim**-0.5
self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
self.rope = rope
self.xattn = xattn
self.proj = nn.Linear(all_head_dim, dim)
if not HAS_XFORMER:
self.xattn = False
def forward(self, x):
B, H, W, C = x.shape
x = x.view(B, -1, C)
N = H * W
q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
## rope
q = self.rope(q).type_as(v)
k = self.rope(k).type_as(v)
if self.xattn:
q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
x = xops.memory_efficient_attention(q, k, v)
x = x.reshape(B, N, -1)
else:
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1).type_as(x)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = x.view(B, H, W, C)
return x
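# --- Editor's illustrative sketch (not part of the original file). ---
# Rope-based attention over a [B, H, W, C] token grid; xattn=False forces the plain
# softmax path so the sketch does not require xformers to be installed.
def _example_rope_attention():
    rope = VisionRotaryEmbeddingFast(dim=768 // 12 // 2, pt_seq_len=16, ft_seq_len=14)
    attn = Attention(dim=768, num_heads=12, rope=rope, xattn=False)
    out = attn(torch.randn(2, 14, 14, 768))
    assert out.shape == (2, 14, 14, 768)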
class ResBottleneckBlock(nn.Module):
"""
The standard bottleneck residual block without the last activation layer.
It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
"""
def __init__(
self,
in_channels,
out_channels,
bottleneck_channels,
inplace=True,
norm_layer=nn.LayerNorm,
activation_layer=nn.GELU,
):
"""
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
bottleneck_channels (int): number of output channels for the 3x3
"bottleneck" conv layers.
norm_layer (callable): normalization layer applied after each conv; a
channel-wise norm such as LayerNorm2d is expected for [N, C, H, W] inputs.
activation_layer (callable): activation layer applied after the first two convs.
"""
super().__init__()
self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, 1, bias=False)
self.norm1 = norm_layer(bottleneck_channels)
# note: the default nn.GELU does not accept an inplace argument
self.act1 = activation_layer()
self.conv2 = nn.Conv2d(
bottleneck_channels,
bottleneck_channels,
3,
padding=1,
bias=False,
)
self.norm2 = norm_layer(bottleneck_channels)
self.act2 = activation_layer()
self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, 1, bias=False)
self.norm3 = norm_layer(out_channels)
for layer in [self.conv1, self.conv2, self.conv3]:
nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
if layer.bias is not None:
nn.init.constant_(layer.bias, 0)
for layer in [self.norm1, self.norm2]:
layer.weight.data.fill_(1.0)
layer.bias.data.zero_()
# zero init last norm layer.
self.norm3.weight.data.zero_()
self.norm3.bias.data.zero_()
def forward(self, x):
out = x
for layer in self.children():
out = layer(out)
out = x + out
return out
class EncoderBlock(nn.Module):
"""Transformer encoder block."""
def __init__(
self,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
drop_path: float = 0.0,
rope: nn.Module = None,
use_swiglu: bool = False,
window_size: int = 0,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
self.num_heads = num_heads
# Attention block
self.ln_1 = norm_layer(hidden_dim)
# NOTE: Different from pytorch ViT, may use rope attention
self.use_rope = rope is not None
if rope is not None:
self.self_attention = Attention(hidden_dim, num_heads, rope=rope)
else:
self.self_attention = nn.MultiheadAttention(
hidden_dim, num_heads, dropout=attention_dropout, batch_first=True
)
self.dropout = nn.Dropout(dropout)
# NOTE: Different from pytorch ViT, we add StochasticDepth
self.stochastic_depth = StochasticDepth(drop_path, mode="row")
# NOTE: Different from pytorch ViT, may use SwiGLU as MLPBlock
self.ln_2 = norm_layer(hidden_dim)
if use_swiglu:
self.mlp = SwiGLU(hidden_dim, mlp_dim, norm_layer=norm_layer)
else:
self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
self.window_size = window_size
def forward(self, input: torch.Tensor):
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
x = self.ln_1(input)
if self.use_rope:
# NOTE: for backbone variants of EVA02, remove batch_class_token
batch_class_token, x = x[:, :1, :], x[:, 1:, :]
b, n, d = x.shape
n_h = n_w = int(n**0.5)
assert n == n_h * n_w, "height and width of image must be equal"
x = x.view(b, n_h, n_w, d)
# window partition following EVA02
if self.window_size > 0:
x, pad_hw = window_partition(x, self.window_size)
x = self.self_attention(x)
# NOTE: reverse window partition, following EVA02
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (n_h, n_w))
x = x.view(b, n, d)
x = torch.cat([batch_class_token, x], dim=1)
else:
x, _ = self.self_attention(x, x, x, need_weights=False)
x = self.stochastic_depth(self.dropout(x))
x = x + input
y = self.ln_2(x)
y = self.mlp(y)
y = self.stochastic_depth(y)
return x + y
class Encoder(nn.Module):
"""Transformer Model Encoder for sequence to sequence translation"""
def __init__(
self,
seq_length: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
drop_path_rate: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
use_rope: bool = False,
use_swiglu: bool = False,
window_block_indexes: Tuple[int] = (),
patch_size: int = None,
image_size: int = None,
window_size: int = 0,
):
super().__init__()
# NOTE: EVA02 use different pos_embed
self.use_rope = use_rope
if self.use_rope:
# Initialize absolute positional embedding with pretrain image size.
num_patches = (224 // patch_size) * (224 // patch_size)
num_positions = num_patches + 1
self.pos_embedding = nn.Parameter(torch.zeros(1, num_positions, hidden_dim))
nn.init.trunc_normal_(self.pos_embedding, std=0.02)
else:
# Note that batch_size is on the first dim because
# we have batch_first=True in nn.MultiheadAttention() by default
self.pos_embedding = nn.Parameter(
torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)
) # from BERT
self.dropout = nn.Dropout(dropout)
layers: OrderedDict[str, nn.Module] = OrderedDict()
# rope
if use_rope:
assert image_size is not None and patch_size is not None, """
image_size and patch_size cannot be None when using rope
"""
self.rope_win = VisionRotaryEmbeddingFast(
dim=hidden_dim // num_heads // 2,
pt_seq_len=patch_size,
ft_seq_len=window_size,
)
self.rope_glb = VisionRotaryEmbeddingFast(
dim=hidden_dim // num_heads // 2,
pt_seq_len=patch_size,
ft_seq_len=image_size // patch_size,
)
# NOTE: stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
for i in range(num_layers):
if use_rope:
rope = self.rope_win if i in window_block_indexes else self.rope_glb
else:
rope = None
cur_window_size = window_size if use_rope and i in window_block_indexes else 0
layers[f"encoder_layer_{i}"] = EncoderBlock(
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
drop_path=dpr[i],
rope=rope,
use_swiglu=use_swiglu,
window_size=cur_window_size,
norm_layer=norm_layer,
)
self.layers = nn.Sequential(layers)
self.ln = norm_layer(hidden_dim)
def forward(self, input: torch.Tensor):
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
if self.use_rope:
# remove batch_class_token
cls_embedding, pos_embedding = self.pos_embedding[:, :1], self.pos_embedding[:, 1:]
patch_size = int(pos_embedding.shape[1]**0.5)
assert patch_size * patch_size == pos_embedding.shape[1]
image_size = int((input.shape[1] - 1)**0.5)
assert image_size * image_size + 1 == input.shape[1]
if patch_size != image_size:
pos_embedding = pos_embedding.view(1, patch_size, patch_size, -1)
pos_embedding = F.interpolate(
pos_embedding.permute(0, 3, 1, 2),
size=(image_size, image_size),
mode="bicubic",
align_corners=False,
)
pos_embedding = pos_embedding.permute(0, 2, 3, 1)
pos_embedding = pos_embedding.view(1, image_size * image_size, -1)
pos_embedding = torch.cat([cls_embedding, pos_embedding], 1)
else:
pos_embedding = self.pos_embedding
input = input + pos_embedding
return self.ln(self.layers(self.dropout(input)))
class VisionTransformer(nn.Module):
"""This module implements Vision Transformer as per https://arxiv.org/abs/2010.11929.
and Vision Transformer (ViT) backbone as per :paper:`vitdet`. Exploring Plain Vision
Transformer Backbones for Object Detection", https://arxiv.org/abs/2203.16527.
"""
def __init__(
self,
image_size: int,
patch_size: int = 16,
num_layers: int = 12,
num_heads: int = 12,
hidden_dim: int = 768,
mlp_dim: int = 3072,
dropout: float = 0.0,
attention_dropout: float = 0.0,
num_classes: int = 1000,
representation_size: Optional[int] = None,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
conv_stem_configs: Optional[List[ConvStemConfig]] = None,
# following EVA-02
drop_path_rate: float = 0.0,
use_rope: bool = False,
use_swiglu: bool = False,
window_size: int = 0,
window_block_indexes: Tuple = (),
):
super().__init__()
torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
self.image_size = image_size
self.patch_size = patch_size
self.hidden_dim = hidden_dim
self.mlp_dim = mlp_dim
self.attention_dropout = attention_dropout
self.dropout = dropout
self.num_classes = num_classes
self.representation_size = representation_size
self.drop_path_rate = drop_path_rate
self.norm_layer = norm_layer
if conv_stem_configs is not None:
# As per https://arxiv.org/abs/2106.14881
seq_proj = nn.Sequential()
prev_channels = 3
for i, conv_stem_layer_config in enumerate(conv_stem_configs):
seq_proj.add_module(
f"conv_bn_relu_{i}",
Conv2dNormActivation(
in_channels=prev_channels,
out_channels=conv_stem_layer_config.out_channels,
kernel_size=conv_stem_layer_config.kernel_size,
stride=conv_stem_layer_config.stride,
norm_layer=conv_stem_layer_config.norm_layer,
activation_layer=conv_stem_layer_config.activation_layer,
),
)
prev_channels = conv_stem_layer_config.out_channels
seq_proj.add_module(
"conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
)
self.conv_proj: nn.Module = seq_proj
else:
self.conv_proj = nn.Conv2d(
in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
)
seq_length = (image_size // patch_size)**2
# Add a class token
self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
seq_length += 1
self.encoder = Encoder(
seq_length,
num_layers,
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
drop_path_rate,
norm_layer,
use_rope,
use_swiglu,
window_block_indexes,
patch_size,
image_size,
window_size,
)
self.seq_length = seq_length
heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
if representation_size is None:
heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
else:
heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
heads_layers["act"] = nn.Tanh()
heads_layers["head"] = nn.Linear(representation_size, num_classes)
self.heads = nn.Sequential(heads_layers)
if isinstance(self.conv_proj, nn.Conv2d):
# Init the patchify stem
fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
if self.conv_proj.bias is not None:
nn.init.zeros_(self.conv_proj.bias)
elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
# Init the last 1x1 conv of the conv stem
nn.init.normal_(
self.conv_proj.conv_last.weight,
mean=0.0,
std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
)
if self.conv_proj.conv_last.bias is not None:
nn.init.zeros_(self.conv_proj.conv_last.bias)
if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
fan_in = self.heads.pre_logits.in_features
nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
nn.init.zeros_(self.heads.pre_logits.bias)
if isinstance(self.heads.head, nn.Linear):
nn.init.zeros_(self.heads.head.weight)
nn.init.zeros_(self.heads.head.bias)
def _process_input(self, x: torch.Tensor) -> torch.Tensor:
n, c, h, w = x.shape
p = self.patch_size
torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!")
torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!")
n_h = h // p
n_w = w // p
# (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
x = self.conv_proj(x)
# (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
x = x.reshape(n, self.hidden_dim, n_h * n_w)
# (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
# The self attention layer expects inputs in the format (N, S, E)
# where S is the source sequence length, N is the batch size, E is the
# embedding dimension
x = x.permute(0, 2, 1)
return x
def forward(self, x):
# Reshape and permute the input tensor
x = self._process_input(x)
n = x.shape[0]
# Expand the class token to the full batch
batch_class_token = self.class_token.expand(n, -1, -1)
x = torch.cat([batch_class_token, x], dim=1)
x = self.encoder(x)
# Classifier "token" as used by standard language architectures
x = x[:, 0]
x = self.heads(x)
return x
class VisionTransformerNoHead(VisionTransformer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
del self.heads
def _process_input(self, x: Tensor) -> Tensor:
h, w = x.shape[-2:]
torch._assert(
h <= self.image_size, f"Image height must not exceed {self.image_size} but got {h}!"
)
torch._assert(
w <= self.image_size, f"Image width must not exceed {self.image_size} but got {w}!"
)
x = F.pad(x, (0, self.image_size - w, 0, self.image_size - h), value=0)
n, _, h, w = x.shape
n_h = h // self.patch_size
n_w = w // self.patch_size
# (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
x = self.conv_proj(x)
# (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
x = x.reshape(n, self.hidden_dim, n_h * n_w)
# (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
# The self attention layer expects inputs in the format (N, S, E)
# where S is the source sequence length, N is the batch size, E is the
# embedding dimension
x = x.permute(0, 2, 1)
return x
def forward(self, x):
h, w = x.shape[-2:]
# Reshape and permute the input tensor
x = self._process_input(x)
n = x.shape[0]
# Expand the class token to the full batch
batch_class_token = self.class_token.expand(n, -1, -1)
x = torch.cat([batch_class_token, x], dim=1)
x = self.encoder(x)
# remove class token
x = x[:, 1:]
num_seq = x.shape[1]
size = int(num_seq**0.5)
assert size * size == num_seq
x = x.view(n, size, size, -1)
n_h = h // self.patch_size
n_w = w // self.patch_size
x = x[:, :n_h, :n_w, :].contiguous().permute(0, 3, 1, 2)
# (b, c, h, w)
return x
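# --- Editor's illustrative sketch (not part of the original file). ---
# VisionTransformerNoHead pads the input up to image_size, runs the encoder, drops
# the class token and returns a [B, C, H/patch, W/patch] feature map cropped back to
# the original (unpadded) grid. A deliberately tiny configuration is used here.
def _example_vit_no_head():
    vit = VisionTransformerNoHead(
        image_size=224, patch_size=16, num_layers=2, num_heads=4, hidden_dim=64, mlp_dim=128
    )
    feat = vit(torch.randn(1, 3, 160, 192))
    assert feat.shape == (1, 64, 160 // 16, 192 // 16)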
class SimpleFeaturePyramid(nn.Module):
def __init__(self, in_channels, out_channels, scale_factors, extra_block=False, norm_layer=LayerNorm2d):
super(SimpleFeaturePyramid, self).__init__()
self.scale_factors = scale_factors
for scale in scale_factors:
out_dim = in_channels
if scale == 4.0:
layers = [
nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2),
norm_layer(in_channels // 2),
nn.GELU(),
nn.ConvTranspose2d(in_channels // 2, in_channels // 4, kernel_size=2, stride=2),
]
out_dim = in_channels // 4
elif scale == 2.0:
layers = [nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)]
out_dim = in_channels // 2
elif scale == 1.0:
layers = []
elif scale == 0.5:
layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
raise NotImplementedError(f"scale_factor={scale} is not supported yet.")
layers.extend([
Conv2dNormActivation(out_dim, out_channels, kernel_size=1, norm_layer=norm_layer),
Conv2dNormActivation(
out_channels, out_channels, kernel_size=3, padding=1, norm_layer=norm_layer
),
])
layers = nn.Sequential(*layers)
stage = 4 - int(math.log2(scale))
self.add_module(f"stage_{stage}", layers)
self.in_channels = in_channels
self.extra_block = extra_block
def forward(self, x):
results = {}
for stage_index in range(6):
cur_stage = getattr(self, f"stage_{stage_index}", None)
if cur_stage is None:
continue
cur_stage_feature = cur_stage(x)
# layer1, layer2, layer3, layer4
# p2, p3, p4, p5
results[f"layer{stage_index - 1}"] = cur_stage_feature
last_stage_index = stage_index
if self.extra_block:
extra_feature = F.max_pool2d(cur_stage_feature, kernel_size=1, stride=2, padding=0)
results[f"layer{last_stage_index}"] = extra_feature
return results
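# --- Editor's illustrative sketch (not part of the original file). ---
# SimpleFeaturePyramid turns the single-scale ViT feature map into a pyramid:
# scale factors 4.0/2.0/1.0/0.5 correspond to strides 4/8/16/32 (keys layer1..layer4).
def _example_simple_fpn():
    fpn = SimpleFeaturePyramid(in_channels=768, out_channels=256, scale_factors=(4.0, 2.0, 1.0, 0.5))
    outs = fpn(torch.randn(1, 768, 64, 64))
    assert {k: tuple(v.shape[-2:]) for k, v in outs.items()} == {
        "layer1": (256, 256), "layer2": (128, 128), "layer3": (64, 64), "layer4": (32, 32)
    }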
class VisionTransformerBackbone(BaseBackbone):
model_weights = {
# The following weights are from torchvision
"vit_b_16":
"https://download.pytorch.org/models/vit_b_16-c867db91.pth",
"vit_b_16_swag_e2e_v1":
"https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth",
"vit_b_16_swag_linear_v1":
"https://download.pytorch.org/models/vit_b_16_lc_swag-4e70ced5.pth",
"vit_b_32":
"https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
"vit_l_16":
"https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
"vit_l_16_swag_e2e_v1":
"https://download.pytorch.org/models/vit_l_16_swag-4f3808c9.pth",
"vit_l_16_swag_linear_v1":
"https://download.pytorch.org/models/vit_l_16_lc_swag-4d563306.pth",
"vit_l_32":
"https://download.pytorch.org/models/vit_l_32-c7638314.pth",
"vit_h_14_swag_e2e_v1":
"https://download.pytorch.org/models/vit_h_14_swag-80465313.pth",
"vit_h_14_swag_linear_v1":
"https://download.pytorch.org/models/vit_h_14_lc_swag-c1eb923e.pth",
}
model_arch = {
"vit_b_16": L(VisionTransformerNoHead)(
image_size=224,
mlp_dim=3072,
url=model_weights["vit_b_16"],
),
"vit_b_32": L(VisionTransformerNoHead)(
image_size=224,
patch_size=32,
mlp_dim=3072,
url=model_weights["vit_b_32"],
),
"vit_l_16": L(VisionTransformerNoHead)(
image_size=224,
num_layers=24,
num_heads=16,
hidden_dim=1024,
mlp_dim=4096,
url=model_weights["vit_l_16"],
),
"vit_l_32": L(VisionTransformerNoHead)(
image_size=224,
patch_size=32,
num_layers=24,
num_heads=16,
hidden_dim=1024,
mlp_dim=4096,
url=model_weights["vit_l_32"],
),
"vit_h_14": L(VisionTransformerNoHead)(
image_size=224,
patch_size=14,
num_layers=32,
num_heads=16,
hidden_dim=1280,
mlp_dim=5120,
url=model_weights["vit_h_14_swag_e2e_v1"],
),
"eva_02_vit_b_4attn_1024": L(VisionTransformerNoHead)(
image_size=1024,
hidden_dim=768,
mlp_dim=2048,
drop_path_rate=0.1,
use_rope=True,
use_swiglu=True,
window_size=16,
window_block_indexes=(0, 1, 3, 4, 6, 7, 9, 10),
),
"eva_02_vit_b_6attn_win32_1536": L(VisionTransformerNoHead)(
image_size=1536,
hidden_dim=768,
mlp_dim=2048,
drop_path_rate=0.1,
use_rope=True,
use_swiglu=True,
window_size=32,
window_block_indexes=(0, 2, 4, 6, 8, 10),
),
"eva_02_vit_l_4attn_1024": L(VisionTransformerNoHead)(
image_size=1024,
num_layers=24,
num_heads=16,
hidden_dim=1024,
mlp_dim=2730,
drop_path_rate=0.4,
use_rope=True,
use_swiglu=True,
window_size=16,
window_block_indexes=[
0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, \
13, 14, 15, 16, 18, 19, 20, 21, 22
]
),
"eva_02_vit_l_8attn_1536": L(VisionTransformerNoHead)(
image_size=1536,
num_layers=24,
num_heads=16,
hidden_dim=1024,
mlp_dim=2730,
drop_path_rate=0.3,
use_rope=True,
use_swiglu=True,
window_size=16,
window_block_indexes=[
0, 1, 3, 4, 6, 7, 9, 10, 12, \
13, 15, 16, 18, 19, 21, 22
]
),
"eva_02_vit_l_8attn_win32_1536": L(VisionTransformerNoHead)(
image_size=1536,
num_layers=24,
num_heads=16,
hidden_dim=1024,
mlp_dim=2730,
drop_path_rate=0.4,
use_rope=True,
use_swiglu=True,
window_size=32,
window_block_indexes=[
0, 1, 3, 4, 6, 7, 9, 10, 12, \
13, 15, 16, 18, 19, 21, 22
],
)
}
def __new__(
self,
arch: str,
weights: Dict = None,
return_indices: Tuple[int] = (0, 1, 2, 3),
**kwargs,
):
# get parameters and instantiate backbone
model_config = self.get_instantiate_config(self, VisionTransformer, arch, kwargs)
default_weight = model_config.pop("url", None)
if "image_size" in model_config and "patch_size" in model_config:
# round image_size up to a multiple of patch_size
ratio = model_config["image_size"] / model_config["patch_size"]
model_config["image_size"] = math.ceil(ratio) * model_config["patch_size"]
vit = instantiate(model_config)
# load state dict
weights = load_checkpoint(default_weight if weights is None else weights)
if isinstance(weights, Dict):
weights = weights["model"] if "model" in weights else weights
self.load_state_dict(vit, weights)
scale_factors = [2**(2 - key) for key in return_indices]
fpn = SimpleFeaturePyramid(
vit.hidden_dim,
256,
scale_factors=scale_factors,
extra_block=4 in return_indices,
)
backbone = nn.Sequential(vit, fpn)
backbone.num_channels = [256] * len(return_indices)
return backbone
import torch
import torchvision
from torch import nn
class DETRBaseTransformer(nn.Module):
"""A base class that contains some methods commonly used in DETR transformer,
such as DeformableTransformer, DabTransformer, DINOTransformer, AlignTransformer.
"""
def __init__(self, num_feature_levels, embed_dim):
super().__init__()
self.embed_dim = embed_dim
self.num_feature_levels = num_feature_levels
self.level_embeds = nn.Parameter(torch.Tensor(num_feature_levels, embed_dim))
self._init_weights_detr_transformer()
def _init_weights_detr_transformer(self):
nn.init.normal_(self.level_embeds)
@staticmethod
def flatten_multi_level(multi_level_elements):
multi_level_elements = torch.cat([e.flatten(-2) for e in multi_level_elements], -1) # (b, [c], s)
if multi_level_elements.ndim == 3:
multi_level_elements.transpose_(1, 2)
return multi_level_elements
def get_lvl_pos_embed(self, multi_level_pos_embeds):
multi_level_pos_embeds = [
p + l.view(1, -1, 1, 1) for p, l in zip(multi_level_pos_embeds, self.level_embeds)
]
return self.flatten_multi_level(multi_level_pos_embeds)
def multi_level_misc(self, multi_level_masks):
if torchvision._is_tracing():
# torch.Tensor.shape does not export well to ONNX
# use operators.shape_as_tensor instead
from torch.onnx import operators
spatial_shapes = [operators.shape_as_tensor(m)[-2:] for m in multi_level_masks]
spatial_shapes = torch.stack(spatial_shapes).to(multi_level_masks[0].device)
else:
spatial_shapes = [m.shape[-2:] for m in multi_level_masks]
spatial_shapes = torch.as_tensor(spatial_shapes, device=multi_level_masks[0].device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack([self.get_valid_ratios(m) for m in multi_level_masks], 1)
return spatial_shapes, level_start_index, valid_ratios
@staticmethod
def get_valid_ratios(mask):
_, h, w = mask.shape
valid_h = torch.sum(~mask[:, :, 0], 1)
valid_w = torch.sum(~mask[:, 0, :], 1)
valid_ratio_h = valid_h.float() / h
valid_ratio_w = valid_w.float() / w
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) # [n, 2]
return valid_ratio
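# --- Editor's illustrative sketch (not part of the original file). ---
# flatten_multi_level concatenates per-level [B, C, H, W] maps into a single
# [B, sum(H*W), C] token sequence, the layout consumed by the transformer.
def _example_flatten_multi_level():
    feats = [torch.randn(2, 256, 32, 32), torch.randn(2, 256, 16, 16)]
    flat = DETRBaseTransformer.flatten_multi_level(feats)
    assert flat.shape == (2, 32 * 32 + 16 * 16, 256)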
class TwostageTransformer(DETRBaseTransformer):
"""A base class that contains some methods commonly used in two-stage transformer,
such as DeformableTransformer, DabTransformer, DINOTransformer, AlignTransformer.
"""
def __init__(self, num_feature_levels, embed_dim):
super().__init__(num_feature_levels, embed_dim)
self.enc_output = nn.Linear(embed_dim, embed_dim)
self.enc_output_norm = nn.LayerNorm(embed_dim)
self._init_weights_two_stage_transformer()
def _init_weights_two_stage_transformer(self):
nn.init.xavier_uniform_(self.enc_output.weight)
nn.init.constant_(self.enc_output.bias, 0.0)
def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
n, s, c = memory.shape
proposals = []
cur = 0
if torchvision._is_tracing():
# avoid iteration warning on torch.Tensor
# convert Tensor to list[Tensor] instead
spatial_shapes = [b.unbind(0) for b in spatial_shapes.unbind(0)]
else:
# use list to avoid small kernel launching when indexing spatial shapes
spatial_shapes = spatial_shapes.tolist()
for lvl, (h, w) in enumerate(spatial_shapes):
mask_flatten = memory_padding_mask[:, cur:(cur + h * w)].view(n, h, w, 1)
valid_h = torch.sum(~mask_flatten[:, :, 0, 0], 1)
valid_w = torch.sum(~mask_flatten[:, 0, :, 0], 1)
grid_y, grid_x = torch.meshgrid(
torch.linspace(0, h - 1, h, dtype=torch.float32, device=memory.device),
torch.linspace(0, w - 1, w, dtype=torch.float32, device=memory.device),
indexing="ij",
)
grid = torch.stack([grid_x, grid_y], -1) # [h, w, 2]
scale = torch.stack([valid_w, valid_h], -1).view(n, 1, 1, 2)
grid = (grid.expand(n, -1, -1, -1) + 0.5) / scale # [n, h, w, 2]
wh = torch.ones_like(grid) * 0.05 * 2.0**lvl
proposal = torch.cat([grid, wh], -1).view(n, -1, 4)
proposals.append(proposal)
cur += h * w
output_proposals = torch.cat(proposals, 1)
output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
output_proposals = torch.log(output_proposals / (1 - output_proposals)) # inverse_sigmoid
output_proposals.masked_fill_(
memory_padding_mask.unsqueeze(-1) | ~output_proposals_valid, float("inf")
)
output_memory = memory * (~memory_padding_mask.unsqueeze(-1)) * (output_proposals_valid)
output_memory = self.enc_output_norm(self.enc_output(output_memory))
return output_memory, output_proposals
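# --- Editor's illustrative sketch (not part of the original file). ---
# Two-stage proposal generation: given flattened encoder memory, its padding mask
# and the per-level spatial shapes, gen_encoder_output_proposals returns projected
# memory plus one (cx, cy, w, h) anchor proposal per token (in inverse-sigmoid space).
def _example_encoder_proposals():
    model = TwostageTransformer(num_feature_levels=2, embed_dim=256)
    spatial_shapes = torch.tensor([[32, 32], [16, 16]])
    memory = torch.randn(2, 32 * 32 + 16 * 16, 256)
    padding_mask = torch.zeros(2, 32 * 32 + 16 * 16, dtype=torch.bool)
    out_memory, proposals = model.gen_encoder_output_proposals(memory, padding_mask, spatial_shapes)
    assert out_memory.shape == memory.shape and proposals.shape == (2, memory.shape[1], 4)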
import torch
from torch import nn
from torch.nn import functional as F
class MLP(nn.Module):
"""Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
self.init_weights()
def init_weights(self):
for layer in self.layers:
nn.init.xavier_uniform_(layer.weight)
nn.init.constant_(layer.bias, 0.0)
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
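# --- Editor's illustrative sketch (not part of the original file). ---
# A typical use of this MLP is as a 3-layer box-regression head that maps per-query
# embeddings to 4 box parameters.
def _example_box_head():
    head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
    boxes = head(torch.randn(2, 100, 256))   # [batch, num_queries, embed_dim]
    assert boxes.shape == (2, 100, 4)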
class SqueezeAndExcitation(nn.Module):
def __init__(self, channels, reduction=16):
super().__init__()
self.conv_mask = nn.Conv2d(channels, 1, kernel_size=1)
self.softmax = nn.Softmax(dim=2)
self.se_module = nn.Sequential(
nn.Conv2d(channels, channels // reduction, kernel_size=1, stride=1, bias=False),
nn.ReLU(inplace=True),
nn.Conv2d(channels // reduction, channels, kernel_size=1, stride=1, bias=False),
nn.Sigmoid(),
)
nn.init.kaiming_normal_(self.conv_mask.weight, mode="fan_in", nonlinearity="relu")
def forward(self, x):
batch, channel, height, width = x.shape
# spatial pool
# b, 1, c, h * w
input_x = x.view(batch, channel, height * width).unsqueeze(1)
# b, 1, h * w, 1
context_mask = self.conv_mask(x).view(batch, 1, height * width)
context_mask = self.softmax(context_mask).unsqueeze(-1)
# b, 1, c, 1
context = torch.matmul(input_x, context_mask)
context = context.view(batch, channel, 1, 1)
return self.se_module(context) * x
class ContextBlock(nn.Module):
"""ContextBlock module in GCNet.
See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'
(https://arxiv.org/abs/1904.11492) for details.
Args:
in_channels (int): Channels of the input feature map.
ratio (float): Ratio of channels of the transform bottleneck.
pooling_type (str): Pooling method for context modeling.
Options are 'att' and 'avg', which stand for attention pooling and
average pooling respectively. Default: 'att'.
fusion_types (Sequence[str]): Fusion methods for feature fusion.
Options are 'channel_add' and 'channel_mul', which stand for channel-wise
addition and multiplication respectively. Default: ('channel_add',)
"""
def __init__(
self,
in_channels: int,
ratio: float,
pooling_type: str = "att",
fusion_types: tuple = ("channel_add",),
):
super().__init__()
assert pooling_type in ["avg", "att"]
assert isinstance(fusion_types, (list, tuple))
valid_fusion_types = ["channel_add", "channel_mul"]
assert all([f in valid_fusion_types for f in fusion_types])
assert len(fusion_types) > 0, "at least one fusion should be used"
self.in_channels = in_channels
self.ratio = ratio
self.planes = int(in_channels * ratio)
self.pooling_type = pooling_type
self.fusion_types = fusion_types
if pooling_type == "att":
self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)
self.softmax = nn.Softmax(dim=2)
else:
self.avg_pool = nn.AdaptiveAvgPool2d(1)
if "channel_add" in fusion_types:
self.channel_add_conv = nn.Sequential(
nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(inplace=True), # yapf: disable
nn.Conv2d(self.planes, self.in_channels, kernel_size=1),
)
else:
self.channel_add_conv = None
if "channel_mul" in fusion_types:
self.channel_mul_conv = nn.Sequential(
nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(inplace=True), # yapf: disable
nn.Conv2d(self.planes, self.in_channels, kernel_size=1),
)
else:
self.channel_mul_conv = None
self.reset_parameters()
def reset_parameters(self):
if self.pooling_type == "att":
nn.init.kaiming_normal_(self.conv_mask.weight, mode="fan_in")
nn.init.constant_(self.conv_mask.bias, 0)
if self.channel_add_conv is not None:
nn.init.constant_(self.channel_add_conv[-1].weight, 0)
nn.init.constant_(self.channel_add_conv[-1].bias, 0)
if self.channel_mul_conv is not None:
nn.init.constant_(self.channel_mul_conv[-1].weight, 0)
nn.init.constant_(self.channel_mul_conv[-1].bias, 0)
def spatial_pool(self, x: torch.Tensor) -> torch.Tensor:
batch, channel, height, width = x.size()
if self.pooling_type == "att":
input_x = x
# [N, C, H * W]
input_x = input_x.view(batch, channel, height * width)
# [N, 1, C, H * W]
input_x = input_x.unsqueeze(1)
# [N, 1, H, W]
context_mask = self.conv_mask(x)
# [N, 1, H * W]
context_mask = context_mask.view(batch, 1, height * width)
# [N, 1, H * W]
context_mask = self.softmax(context_mask)
# [N, 1, H * W, 1]
context_mask = context_mask.unsqueeze(-1)
# [N, 1, C, 1]
context = torch.matmul(input_x, context_mask)
# [N, C, 1, 1]
context = context.view(batch, channel, 1, 1)
else:
# [N, C, 1, 1]
context = self.avg_pool(x)
return context
def forward(self, x: torch.Tensor) -> torch.Tensor:
# [N, C, 1, 1]
context = self.spatial_pool(x)
out = x
if self.channel_mul_conv is not None:
# [N, C, 1, 1]
channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
out = out * channel_mul_term
if self.channel_add_conv is not None:
# [N, C, 1, 1]
channel_add_term = self.channel_add_conv(context)
out = out + channel_add_term
return out
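# --- Editor's illustrative sketch (not part of the original file). ---
# ContextBlock is a drop-in channel-recalibration module: it pools a global context
# vector and adds (and/or multiplies) it back onto the input, preserving the shape.
def _example_context_block():
    gc_block = ContextBlock(in_channels=256, ratio=1 / 16)
    y = gc_block(torch.randn(2, 256, 32, 32))
    assert y.shape == (2, 256, 32, 32)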
import math
from typing import Tuple, Union
import torch
from torch import nn
from torchvision.ops import DeformConv2d
class DeformConv2dPack(nn.Module):
"""This is a pack of deformable convolution that can be used as normal convolution"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int]],
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
bias: Union[bool, str] = True,
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (
kernel_size,
kernel_size,
)
self.in_channels = in_channels
self.kernel_size = kernel_size
self.groups = groups
self.conv_offset = nn.Conv2d(
in_channels,
groups * 2 * kernel_size[0] * kernel_size[1],
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups, # Don't know whether to add groups here
bias=True,
)
self.conv_mask = nn.Conv2d(
in_channels,
groups * kernel_size[0] * kernel_size[1],
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=True,
)
self.deform_conv2d = DeformConv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
self.init_weights()
def init_weights(self) -> None:
self.conv_offset.weight.data.zero_()
self.conv_mask.weight.data.zero_()
self.conv_offset.bias.data.zero_()
self.conv_mask.bias.data.zero_()
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1.0 / math.sqrt(n)
self.deform_conv2d.weight.data.uniform_(-stdv, stdv)
if self.deform_conv2d.bias is not None:
self.deform_conv2d.bias.data.zero_()
def forward(self, x: torch.Tensor) -> torch.Tensor:
offset = self.conv_offset(x)
mask = torch.sigmoid(self.conv_mask(x))
out = self.deform_conv2d(x, offset, mask)
return out
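# --- Editor's illustrative sketch (not part of the original file). ---
# DeformConv2dPack can replace a regular 3x3 convolution: offsets and the modulation
# mask are predicted from the input itself, so the call signature matches nn.Conv2d.
def _example_deform_conv():
    dcn = DeformConv2dPack(64, 128, kernel_size=3, padding=1)
    y = dcn(torch.randn(2, 64, 32, 32))
    assert y.shape == (2, 128, 32, 32)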
import torch
from torch import nn
from torchvision.ops import boxes as box_ops
from util.misc import inverse_sigmoid
class GenerateDNQueries(nn.Module):
def __init__(
self,
num_queries: int = 300,
num_classes: int = 80,
label_embed_dim: int = 256,
denoising_groups: int = 5,
label_noise_prob: float = 0.2,
box_noise_scale: float = 0.4,
with_indicator: bool = False,
):
"""Generate denoising queries for DN-DETR
:param num_queries: Number of total queries in DN-DETR, defaults to 300
:param num_classes: Number of total categories, defaults to 80
:param label_embed_dim: The embedding dimension for label encoding, defaults to 256
:param denoising_groups: Number of noised ground truth groups, defaults to 5
:param label_noise_prob: The probability of the label being noised, defaults to 0.2
:param box_noise_scale: Scaling factor for box noising, defaults to 0.4
:param with_indicator: Whether to add indicator in noised label/box queries, defaults to False
"""
super(GenerateDNQueries, self).__init__()
self.num_queries = num_queries
self.num_classes = num_classes
self.label_embed_dim = label_embed_dim
self.denoising_groups = denoising_groups
self.label_noise_prob = label_noise_prob
self.box_noise_scale = box_noise_scale
self.with_indicator = with_indicator
# leave one dim for indicator mentioned in DN-DETR
if with_indicator:
self.label_encoder = nn.Embedding(num_classes, label_embed_dim - 1)
else:
self.label_encoder = nn.Embedding(num_classes, label_embed_dim)
@staticmethod
def apply_label_noise(labels: torch.Tensor, label_noise_prob: float = 0.2, num_classes: int = 80):
if label_noise_prob > 0:
mask = torch.rand_like(labels.float()) < label_noise_prob
noised_labels = torch.randint_like(labels, 0, num_classes)
noised_labels = torch.where(mask, noised_labels, labels)
return noised_labels
else:
return labels
@staticmethod
def apply_box_noise(boxes: torch.Tensor, box_noise_scale: float = 0.4):
if box_noise_scale > 0:
diff = torch.zeros_like(boxes)
diff[:, :2] = boxes[:, 2:] / 2
diff[:, 2:] = boxes[:, 2:]
boxes += torch.mul((torch.rand_like(boxes) * 2 - 1.0), diff) * box_noise_scale
boxes = boxes.clamp(min=0.0, max=1.0)
return boxes
def generate_query_masks(self, max_gt_num_per_image, device):
noised_query_nums = max_gt_num_per_image * self.denoising_groups
tgt_size = noised_query_nums + self.num_queries
attn_mask = torch.zeros(tgt_size, tgt_size, device=device, dtype=torch.bool)
# matching queries cannot see the denoising (reconstruction) part
attn_mask[noised_query_nums:, :noised_query_nums] = True
for i in range(self.denoising_groups):
start_col = start_row = max_gt_num_per_image * i
end_col = end_row = max_gt_num_per_image * (i + 1)
assert noised_query_nums >= end_col and start_col >= 0, "check attn_mask"
attn_mask[start_row:end_row, :start_col] = True
attn_mask[start_row:end_row, end_col:noised_query_nums] = True
return attn_mask
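# Illustrative note (not from the original source): with denoising_groups = 2,
# max_gt_num_per_image = 2 and a deliberately tiny num_queries = 3, the mask
# above is a 7x7 boolean matrix (True = attention blocked), laid out as
#
#             group 0   group 1   matching
#   group 0   [F F]     [T T]     [F F F]
#   group 1   [T T]     [F F]     [F F F]
#   matching  [T T]     [T T]     [F F F]
#
# i.e. each denoising group attends only to itself and to the matching part,
# while matching queries never see any denoising queries.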
def forward(self, gt_labels_list, gt_boxes_list):
"""
:param gt_labels_list: Classification labels per image in shape ``(num_gt, )``
:param gt_boxes_list: Ground truth bounding boxes per image
with normalized coordinates in format ``(x, y, w, h)`` in shape ``(num_gts, 4)``
:return: Noised label queries, box queries, attention mask and denoising metas.
"""
# concat ground truth labels and boxes in one batch
# e.g. [tensor([0, 1, 2]), tensor([2, 3, 4])] -> tensor([0, 1, 2, 2, 3, 4])
gt_labels = torch.cat(gt_labels_list)
gt_boxes = torch.cat(gt_boxes_list)
# For efficient denoising, repeat the original ground truth labels and boxes to
# create more training denoising samples.
# e.g. tensor([0, 1, 2, 2, 3, 4]) -> tensor([0, 1, 2, 2, 3, 4, 0, 1, 2, 2, 3, 4]) if group = 2.
gt_labels = gt_labels.repeat(self.denoising_groups, 1).flatten()
gt_boxes = gt_boxes.repeat(self.denoising_groups, 1)
# set the device as "gt_labels"
device = gt_labels.device
assert len(gt_labels_list) == len(gt_boxes_list)
batch_size = len(gt_labels_list)
# the number of ground truth per image in one batch
# e.g. [tensor([0, 1]), tensor([2, 3, 4])] -> gt_nums_per_image: [2, 3]
# means there are 2 instances in the first image and 3 instances in the second image
gt_nums_per_image = [x.numel() for x in gt_labels_list]
# Add noise on labels and boxes
noised_labels = self.apply_label_noise(gt_labels, self.label_noise_prob, self.num_classes)
noised_boxes = self.apply_box_noise(gt_boxes, self.box_noise_scale)
noised_boxes = inverse_sigmoid(noised_boxes)
# encoding labels
label_embedding = self.label_encoder(noised_labels)
query_num = label_embedding.shape[0]
# add indicator to label encoding if with_indicator == True
if self.with_indicator:
label_embedding = torch.cat([label_embedding, torch.ones([query_num, 1], device=device)], 1)
# calculate the max number of ground truth in one image inside the batch.
# e.g. gt_nums_per_image = [2, 3] which means
# the first image has 2 instances and the second image has 3 instances
# then the max_gt_num_per_image should be 3.
max_gt_num_per_image = max(gt_nums_per_image)
# the total number of denoising queries depends on the number of denoising groups and the max number of instances.
noised_query_nums = max_gt_num_per_image * self.denoising_groups
# initialize the generated noised queries to zero.
# And the zero initialized queries will be assigned with noised embeddings later.
noised_label_queries = torch.zeros(batch_size, noised_query_nums, self.label_embed_dim, device=device)
noised_box_queries = torch.zeros(batch_size, noised_query_nums, 4, device=device)
# batch index per image: [0, 1, 2, 3] for batch_size == 4
batch_idx = torch.arange(0, batch_size)
# e.g. gt_nums_per_image = [2, 3]
# batch_idx = [0, 1]
# then the "batch_idx_per_instance" equals to [0, 0, 1, 1, 1]
# which indicates which image the instance belongs to.
# cuz the instances has been flattened before.
batch_idx_per_instance = torch.repeat_interleave(batch_idx, torch.tensor(gt_nums_per_image).long())
# indicates which image each noised label belongs to. For example:
# noised label: tensor([0, 1, 2, 2, 3, 4, 0, 1, 2, 2, 3, 4])
# batch_idx_per_group: tensor([0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1])
# which means the first label "tensor([0])" belongs to "image_0".
batch_idx_per_group = batch_idx_per_instance.repeat(self.denoising_groups, 1).flatten()
# Because images in the same batch may contain different numbers of ground truths,
# the noised queries contain some padding.
# Here we compute the indices of the valid queries,
# fill them with the noised embeddings,
# and leave the padding part as zeros.
if len(gt_nums_per_image):
valid_index_per_group = torch.cat([torch.arange(num) for num in gt_nums_per_image])
valid_index_per_group = torch.cat([
valid_index_per_group + max_gt_num_per_image * i for i in range(self.denoising_groups)
]).long()
if len(batch_idx_per_group):
noised_label_queries[(batch_idx_per_group, valid_index_per_group)] = label_embedding
noised_box_queries[(batch_idx_per_group, valid_index_per_group)] = noised_boxes
# generate attention masks for transformer layers
attn_mask = self.generate_query_masks(max_gt_num_per_image, device)
return (
noised_label_queries,
noised_box_queries,
attn_mask,
self.denoising_groups,
max_gt_num_per_image,
)
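# --- Hedged usage sketch (not from the original source) ---
# For a batch of two images with 2 and 3 ground truths and the defaults above
# (5 denoising groups), the shapes work out as follows:
#
#   dn = GenerateDNQueries(num_queries=300, num_classes=80, label_embed_dim=256)
#   label_q, box_q, attn_mask, groups, max_gt = dn(
#       [torch.tensor([1, 2]), torch.tensor([3, 4, 5])],     # labels per image
#       [torch.rand(2, 4), torch.rand(3, 4)],                # (cx, cy, w, h) boxes per image
#   )
#   # label_q: (2, 5 * 3, 256), box_q: (2, 5 * 3, 4), attn_mask: (315, 315)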
class GenerateCDNQueries(GenerateDNQueries):
def __init__(
self,
num_queries: int = 300,
num_classes: int = 80,
label_embed_dim: int = 256,
denoising_nums: int = 100,
label_noise_prob: float = 0.5,
box_noise_scale: float = 1.0,
):
super().__init__(
num_queries=num_queries,
num_classes=num_classes,
label_embed_dim=label_embed_dim,
label_noise_prob=label_noise_prob,
box_noise_scale=box_noise_scale,
denoising_groups=1,
)
self.denoising_nums = denoising_nums
self.label_encoder = nn.Embedding(num_classes, label_embed_dim)
def apply_box_noise(self, boxes: torch.Tensor, box_noise_scale: float = 0.4):
"""
:param boxes: Bounding boxes in format ``(x_c, y_c, w, h)`` with shape ``(num_boxes, 4)``
:param box_noise_scale: Scaling factor for box noising, defaults to 0.4
:return: Noised boxes
"""
num_boxes = len(boxes) // self.denoising_groups // 2
positive_idx = torch.arange(num_boxes, dtype=torch.long, device=boxes.device)
positive_idx = positive_idx.unsqueeze(0).repeat(self.denoising_groups, 1)
positive_idx += (
torch.arange(self.denoising_groups, dtype=torch.long, device=boxes.device).unsqueeze(1) *
num_boxes * 2
)
positive_idx = positive_idx.flatten()
negative_idx = positive_idx + num_boxes
if box_noise_scale > 0:
diff = torch.zeros_like(boxes)
diff[:, :2] = boxes[:, 2:] / 2
diff[:, 2:] = boxes[:, 2:] / 2
rand_sign = torch.randint_like(boxes, low=0, high=2, dtype=torch.float32) * 2.0 - 1.0
rand_part = torch.rand_like(boxes)
rand_part[negative_idx] += 1.0
rand_part *= rand_sign
xyxy_boxes = box_ops._box_cxcywh_to_xyxy(boxes)
xyxy_boxes += torch.mul(rand_part, diff) * box_noise_scale
xyxy_boxes = xyxy_boxes.clamp(min=0.0, max=1.0)
boxes = box_ops._box_xyxy_to_cxcywh(xyxy_boxes)
return boxes
def forward(self, gt_labels_list, gt_boxes_list):
"""_summary_
:param gt_labels_list: Ground truth bounding boxes per image
with normalized coordinates in format ``(x, y, w, h)`` in shape ``(num_gts, 4)``
:param gt_boxes_list: Classification labels per image in shape ``(num_gt, )``
:return: Noised label queries, box queries, attention mask and denoising metas.
"""
# the number of ground truth per image in one batch
# e.g. [tensor([0, 1]), tensor([2, 3, 4])] -> gt_nums_per_image: [2, 3]
# means there are 2 instances in the first image and 3 instances in the second image
gt_nums_per_image = [x.numel() for x in gt_labels_list]
# calculate the max number of ground truth in one image inside the batch.
# e.g. gt_nums_per_image = [2, 3] which means
# the first image has 2 instances and the second image has 3 instances
# then the max_gt_num_per_image should be 3.
max_gt_num_per_image = max(gt_nums_per_image)
# derive the number of denoising groups so that groups * max_gt_num_per_image is roughly denoising_nums; fall back to 1 when there is no ground truth
denoising_groups = self.denoising_nums * max_gt_num_per_image // max(max_gt_num_per_image**2, 1)
self.denoising_groups = max(denoising_groups, 1)
# concat ground truth labels and boxes in one batch
# e.g. [tensor([0, 1, 2]), tensor([2, 3, 4])] -> tensor([0, 1, 2, 2, 3, 4])
gt_labels = torch.cat(gt_labels_list)
gt_boxes = torch.cat(gt_boxes_list)
# For efficient denoising, repeat the original ground truth labels and boxes to
# create more training denoising samples.
# each group has positive and negative. e.g. if group = 2, tensor([0, 1, 2, 2, 3, 4]) ->
# tensor([0, 1, 2, 2, 3, 4, 0, 1, 2, 2, 3, 4, 0, 1, 2, 2, 3, 4, 0, 1, 2, 2, 3, 4]).
gt_labels = gt_labels.repeat(self.denoising_groups * 2, 1).flatten()
gt_boxes = gt_boxes.repeat(self.denoising_groups * 2, 1)
# set the device as "gt_labels"
device = gt_labels.device
assert len(gt_labels_list) == len(gt_boxes_list)
batch_size = len(gt_labels_list)
# Add noise on labels and boxes
noised_labels = self.apply_label_noise(gt_labels, self.label_noise_prob * 0.5, self.num_classes)
noised_boxes = self.apply_box_noise(gt_boxes, self.box_noise_scale)
noised_boxes = inverse_sigmoid(noised_boxes)
# encoding labels
label_embedding = self.label_encoder(noised_labels)
# the total number of denoising queries depends on the number of denoising groups and the max number of instances.
noised_query_nums = max_gt_num_per_image * self.denoising_groups * 2
# initialize the generated noised queries to zero.
# And the zero initialized queries will be assigned with noised embeddings later.
noised_label_queries = torch.zeros(batch_size, noised_query_nums, self.label_embed_dim, device=device)
noised_box_queries = torch.zeros(batch_size, noised_query_nums, 4, device=device)
# batch index per image: [0, 1, 2, 3] for batch_size == 4
batch_idx = torch.arange(0, batch_size)
# e.g. gt_nums_per_image = [2, 3]
# batch_idx = [0, 1]
# then the "batch_idx_per_instance" equals to [0, 0, 1, 1, 1]
# which indicates which image the instance belongs to.
# cuz the instances has been flattened before.
batch_idx_per_instance = torch.repeat_interleave(
batch_idx, torch.tensor(gt_nums_per_image, dtype=torch.long)
)
# indicates which image each noised label belongs to. For example:
# noised label: tensor([0, 1, 2, 2, 3, 4, 0, 1, 2, 2, 3, 4, 0, 1, 2, 2, 3, 4, 0, 1, 2, 2, 3, 4])
# batch_idx_per_group: tensor([0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1])
# which means the first label "tensor([0])" belongs to "image_0".
batch_idx_per_group = batch_idx_per_instance.repeat(self.denoising_groups * 2, 1).flatten()
# Because images in the same batch may contain different numbers of ground truths,
# the noised queries contain some padding.
# Here we compute the indices of the valid queries,
# fill them with the noised embeddings,
# and leave the padding part as zeros.
if len(gt_nums_per_image):
valid_index_per_group = torch.cat([torch.arange(num) for num in gt_nums_per_image])
valid_index_per_group = torch.cat([
valid_index_per_group + max_gt_num_per_image * i for i in range(self.denoising_groups * 2)
]).long()
if len(batch_idx_per_group):
noised_label_queries[(batch_idx_per_group, valid_index_per_group)] = label_embedding
noised_box_queries[(batch_idx_per_group, valid_index_per_group)] = noised_boxes
# generate attention masks for transformer layers
attn_mask = self.generate_query_masks(2 * max_gt_num_per_image, device)
return (
noised_label_queries,
noised_box_queries,
attn_mask,
self.denoising_groups,
max_gt_num_per_image * 2,
)
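# --- Hedged usage sketch (illustrative only) ---
# Contrastive denoising queries for a batch of two images with 2 and 3 ground
# truths; every group carries a positive and a negative copy of each ground truth.
if __name__ == "__main__":
    cdn = GenerateCDNQueries(num_queries=300, num_classes=80, label_embed_dim=256, denoising_nums=100)
    gt_labels_list = [torch.tensor([1, 2]), torch.tensor([3, 4, 5])]
    gt_boxes_list = [torch.rand(2, 4).clamp(0.1, 0.9), torch.rand(3, 4).clamp(0.1, 0.9)]
    label_q, box_q, attn_mask, groups, max_gt = cdn(gt_labels_list, gt_boxes_list)
    max_per_image = 3
    assert label_q.shape == (2, groups * 2 * max_per_image, 256)
    assert box_q.shape == (2, groups * 2 * max_per_image, 4)
    assert max_gt == 2 * max_per_image   # positive and negative copies per instance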
from torch.nn import functional as F
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
prob = inputs.sigmoid()
target_score = targets.to(inputs.dtype)
weight = (1 - alpha) * prob**gamma * (1 - targets) + targets * alpha * (1 - prob)**gamma
# following the original implementation, sigmoid_focal_loss keeps the gradient flowing through the weight term
loss = F.binary_cross_entropy_with_logits(inputs, target_score, reduction="none")
loss = loss * weight
# we use sum/num to replace mean to avoid NaN
return (loss.sum(1) / max(loss.shape[1], 1)).sum() / num_boxes
def vari_sigmoid_focal_loss(inputs, targets, gt_score, num_boxes, alpha: float = 0.25, gamma: float = 2):
prob = inputs.sigmoid().detach()  # the PyTorch version of RT-DETR detaches here, while the Paddle version does not
target_score = targets * gt_score.unsqueeze(-1)
weight = (1 - alpha) * prob.pow(gamma) * (1 - targets) + target_score
loss = F.binary_cross_entropy_with_logits(inputs, target_score, weight=weight, reduction="none")
# we use sum/num to replace mean to avoid NaN
return (loss.sum(1) / max(loss.shape[1], 1)).sum() / num_boxes
def ia_bce_loss(inputs, targets, gt_score, num_boxes, k: float = 0.25, alpha: float = 0, gamma: float = 2):
prob = inputs.sigmoid().detach()
# calculate iou_aware_score and constrain the value following original implementation
iou_aware_score = prob**k * gt_score.unsqueeze(-1)**(1 - k)
iou_aware_score = iou_aware_score.clamp(min=0.01)
target_score = targets * iou_aware_score
weight = (1 - alpha) * prob.pow(gamma) * (1 - targets) + targets
loss = F.binary_cross_entropy_with_logits(inputs, target_score, weight=weight, reduction="none")
# we use sum/num to replace mean to avoid NaN
return (loss.sum(1) / max(loss.shape[1], 1)).sum() / num_boxes
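# --- Hedged usage sketch (illustrative only) ---
# Three queries over 80 classes with one positive class each; num_boxes is the
# normalizer used by DETR-style criteria, and gt_score plays the role of the
# matched IoU in the variants above.
if __name__ == "__main__":
    import torch

    logits = torch.randn(3, 80)
    targets = torch.zeros(3, 80)
    targets[torch.arange(3), torch.tensor([1, 5, 9])] = 1.0
    gt_score = torch.tensor([0.9, 0.7, 0.8])
    print(float(sigmoid_focal_loss(logits, targets, num_boxes=3)))
    print(float(vari_sigmoid_focal_loss(logits, targets, gt_score, num_boxes=3)))
    print(float(ia_bce_loss(logits, targets, gt_score, num_boxes=3)))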
"""These modules are borrowed from torchvision.ops.misc, for usage in low-version pytorch and torchvision"""
import warnings
from typing import Callable, List, Optional
import torch
from torch import Tensor
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed
Args:
num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
eps (float): a value added to the denominator for numerical stability. Default: 1e-5
"""
def __init__(
self,
num_features: int,
eps: float = 1e-5,
):
super().__init__()
self.eps = eps
self.register_buffer("weight", torch.ones(num_features))
self.register_buffer("bias", torch.zeros(num_features))
self.register_buffer("running_mean", torch.zeros(num_features))
self.register_buffer("running_var", torch.ones(num_features))
def _load_from_state_dict(
self,
state_dict: dict,
prefix: str,
local_metadata: dict,
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def forward(self, x: Tensor) -> Tensor:
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
scale = w * (rv + self.eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})"
class ConvNormActivation(torch.nn.Sequential):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: Optional[int] = None,
groups: int = 1,
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
dilation: int = 1,
inplace: Optional[bool] = True,
bias: Optional[bool] = None,
conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
) -> None:
if padding is None:
padding = (kernel_size - 1) // 2 * dilation
if bias is None:
bias = norm_layer is None
layers = [
conv_layer(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation=dilation,
groups=groups,
bias=bias,
)
]
if norm_layer is not None:
layers.append(norm_layer(out_channels))
if activation_layer is not None:
params = {} if inplace is None else {"inplace": inplace}
layers.append(activation_layer(**params))
super().__init__(*layers)
self.out_channels = out_channels
if self.__class__ == ConvNormActivation:
warnings.warn(
"Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
)
class Conv2dNormActivation(ConvNormActivation):
"""
Configurable block used for Convolution2d-Normalization-Activation blocks.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
kernel_size: (int, optional): Size of the convolving kernel. Default: 3
stride (int, optional): Stride of the convolution. Default: 1
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
dilation (int): Spacing between kernel elements. Default: 1
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: Optional[int] = None,
groups: int = 1,
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
dilation: int = 1,
inplace: Optional[bool] = True,
bias: Optional[bool] = None,
) -> None:
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups,
norm_layer,
activation_layer,
dilation,
inplace,
bias,
torch.nn.Conv2d,
)
class Conv3dNormActivation(ConvNormActivation):
"""
Configurable block used for Convolution3d-Normalization-Activation blocks.
Args:
in_channels (int): Number of channels in the input video.
out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
kernel_size: (int, optional): Size of the convolving kernel. Default: 3
stride (int, optional): Stride of the convolution. Default: 1
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm3d``
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
dilation (int): Spacing between kernel elements. Default: 1
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: Optional[int] = None,
groups: int = 1,
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d,
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
dilation: int = 1,
inplace: Optional[bool] = True,
bias: Optional[bool] = None,
) -> None:
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups,
norm_layer,
activation_layer,
dilation,
inplace,
bias,
torch.nn.Conv3d,
)
class SqueezeExcitation(torch.nn.Module):
"""
This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1).
Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in eq. 3.
Args:
input_channels (int): Number of channels in the input image
squeeze_channels (int): Number of squeeze channels
activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU``
scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid``
"""
def __init__(
self,
input_channels: int,
squeeze_channels: int,
activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
) -> None:
super().__init__()
self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1)
self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1)
self.activation = activation()
self.scale_activation = scale_activation()
def _scale(self, input: Tensor) -> Tensor:
scale = self.avgpool(input)
scale = self.fc1(scale)
scale = self.activation(scale)
scale = self.fc2(scale)
return self.scale_activation(scale)
def forward(self, input: Tensor) -> Tensor:
scale = self._scale(input)
return scale * input
class MLP(torch.nn.Sequential):
"""This block implements the multi-layer perceptron (MLP) module.
Args:
in_channels (int): Number of channels of the input
hidden_channels (List[int]): List of the hidden channel dimensions
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None``
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
bias (bool): Whether to use bias in the linear layer. Default ``True``
dropout (float): The probability for the dropout layer. Default: 0.0
"""
def __init__(
self,
in_channels: int,
hidden_channels: List[int],
norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
inplace: Optional[bool] = True,
bias: bool = True,
dropout: float = 0.0,
):
# The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
# https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
params = {} if inplace is None else {"inplace": inplace}
layers = []
in_dim = in_channels
for hidden_dim in hidden_channels[:-1]:
layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
if norm_layer is not None:
layers.append(norm_layer(hidden_dim))
layers.append(activation_layer(**params))
layers.append(torch.nn.Dropout(dropout, **params))
in_dim = hidden_dim
layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
layers.append(torch.nn.Dropout(dropout, **params))
super().__init__(*layers)
class Permute(torch.nn.Module):
"""This module returns a view of the tensor input with its dimensions permuted.
Args:
dims (List[int]): The desired ordering of dimensions
"""
def __init__(self, dims: List[int]):
super().__init__()
self.dims = dims
def forward(self, x: Tensor) -> Tensor:
return torch.permute(x, self.dims)
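# --- Hedged usage sketch (illustrative only) ---
# Composing the blocks above on a toy input to check shapes; all sizes below
# are arbitrary example values.
if __name__ == "__main__":
    x = torch.randn(2, 3, 64, 64)
    stem = Conv2dNormActivation(3, 16, kernel_size=3, stride=2)   # padding inferred -> [2, 16, 32, 32]
    se = SqueezeExcitation(input_channels=16, squeeze_channels=4)
    head = MLP(16, [32, 10])                                      # Linear -> ReLU -> Dropout -> Linear -> Dropout
    feats = se(stem(x)).mean(dim=(2, 3))                          # global average pool -> [2, 16]
    assert head(feats).shape == (2, 10)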
import math
import os
import warnings
import torch
import torchvision
from torch import Tensor, nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_uniform_
from torch.utils.cpp_extension import load
_C = None
if torch.cuda.is_available():
try:
_C = load(
"MultiScaleDeformableAttention",
sources=[f"{os.path.dirname(__file__)}/ops/cuda/ms_deform_attn_cuda.cu"],
extra_cflags=["-O2"],
verbose=True,
)
except Exception as e:
warnings.warn(f"Failed to load MultiScaleDeformableAttention C++ extension: {e}")
else:
warnings.warn("No cuda is available, skip loading MultiScaleDeformableAttention C++ extention")
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
return (n & (n - 1) == 0) and n != 0
class MultiScaleDeformableAttnFunction(Function):
@staticmethod
def forward(
ctx,
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
im2col_step,
):
ctx.im2col_step = im2col_step
output = _C.ms_deform_attn_forward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
ctx.im2col_step,
)
ctx.save_for_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
) = ctx.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = _C.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output,
ctx.im2col_step,
)
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
def bilinear_grid_sample(im, grid, align_corners=False):
"""Given an input and a flow-field grid, computes the output using input
values and pixel locations from grid. Supported only bilinear interpolation
method to sample the input pixels.
Args:
im (torch.Tensor): Input feature map, shape (N, C, H, W)
grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
align_corners (bool): If set to True, the extrema (-1 and 1) are
considered as referring to the center points of the input's
corner pixels. If set to False, they are instead considered as
referring to the corner points of the input's corner pixels,
making the sampling more resolution agnostic.
Returns:
torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
"""
n, c, h, w = im.shape
gn, gh, gw, _ = grid.shape
assert n == gn
x = grid[:, :, :, 0]
y = grid[:, :, :, 1]
if align_corners:
x = ((x + 1) / 2) * (w - 1)
y = ((y + 1) / 2) * (h - 1)
else:
x = ((x + 1) * w - 1) / 2
y = ((y + 1) * h - 1) / 2
x = x.view(n, -1)
y = y.view(n, -1)
x0 = torch.floor(x).long()
y0 = torch.floor(y).long()
x1 = x0 + 1
y1 = y0 + 1
wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
wb = ((x1 - x) * (y - y0)).unsqueeze(1)
wc = ((x - x0) * (y1 - y)).unsqueeze(1)
wd = ((x - x0) * (y - y0)).unsqueeze(1)
# Pad with zeros to mimic grid_sample's default padding_mode="zeros"
im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
padded_h = h + 2
padded_w = w + 2
# save points positions after padding
x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1
# Clip coordinates to padded image size
x0 = torch.clamp_(x0, 0, padded_w - 1)
x1 = torch.clamp_(x1, 0, padded_w - 1)
y0 = torch.clamp_(y0, 0, padded_h - 1)
y1 = torch.clamp_(y1, 0, padded_h - 1)
im_padded = im_padded.view(n, c, -1)
x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
Ia = torch.gather(im_padded, 2, x0_y0)
Ib = torch.gather(im_padded, 2, x0_y1)
Ic = torch.gather(im_padded, 2, x1_y0)
Id = torch.gather(im_padded, 2, x1_y1)
return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
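# --- Hedged sanity check (not from the original source) ---
# bilinear_grid_sample is a tracing-friendly stand-in for F.grid_sample with
# bilinear interpolation and zero padding, so the two should agree closely for
# grids in [-1, 1]:
#
#   im = torch.randn(1, 3, 16, 16)
#   grid = torch.rand(1, 8, 8, 2) * 2 - 1
#   ref = F.grid_sample(im, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
#   assert torch.allclose(bilinear_grid_sample(im, grid), ref, atol=1e-5)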
def multi_scale_deformable_attn_pytorch(
value: torch.Tensor,
value_spatial_shapes: torch.Tensor,
sampling_locations: torch.Tensor,
attention_weights: torch.Tensor,
) -> torch.Tensor:
bs, _, num_heads, embed_dims = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split(value_spatial_shapes.prod(-1).unbind(0), dim=1)
sampling_grids = 2 * sampling_locations - 1
if torchvision._is_tracing():
# avoid iteration warning on torch.Tensor
# convert Tensor to list[Tensor] instead
value_spatial_shapes = [b.unbind(0) for b in value_spatial_shapes.unbind(0)]
else:
# use list to avoid small kernel launching when indexing spatial shapes
value_spatial_shapes = value_spatial_shapes.tolist()
sampling_value_list = []
for level, (H_, W_) in enumerate(value_spatial_shapes):
# bs, H_*W_, num_heads, embed_dims ->
# bs, H_*W_, num_heads*embed_dims ->
# bs, num_heads*embed_dims, H_*W_ ->
# bs*num_heads, embed_dims, H_, W_
value_l_ = (value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_))
# bs, num_queries, num_heads, num_points, 2 ->
# bs, num_heads, num_queries, num_points, 2 ->
# bs*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
# bs*num_heads, embed_dims, num_queries, num_points
if torchvision._is_tracing():
sampling_value_l_ = bilinear_grid_sample(
value_l_,
sampling_grid_l_.contiguous(),
align_corners=False,
)
else:
sampling_value_l_ = F.grid_sample(
value_l_,
sampling_grid_l_,
mode="bilinear",
padding_mode="zeros",
align_corners=False,
)
sampling_value_list.append(sampling_value_l_)
# (bs, num_queries, num_heads, num_levels, num_points) ->
# (bs, num_heads, num_queries, num_levels, num_points) ->
# (bs, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
bs * num_heads, 1, num_queries, num_levels * num_points
)
output = torch.stack(sampling_value_list, dim=-2).flatten(-2)
output = (output * attention_weights).sum(-1)
output = output.view(bs, num_heads * embed_dims, num_queries)
return output.transpose(1, 2).contiguous()
class MultiScaleDeformableAttention(nn.Module):
"""Multi-Scale Deformable Attention Module used in Deformable-DETR
`Deformable DETR: Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
"""
def __init__(
self,
embed_dim: int = 256,
num_levels: int = 4,
num_heads: int = 8,
num_points: int = 4,
img2col_step: int = 64,
):
"""Initialization function of MultiScaleDeformableAttention
:param embed_dim: The embedding dimension of Attention, defaults to 256
:param num_levels: The number of feature map used in Attention, defaults to 4
:param num_heads: The number of attention heads, defaults to 8
:param num_points: The number of sampling points for each query
in each head, defaults to 4
:param img2col_step: The step used in image_to_column, defaults to 64
"""
super().__init__()
if embed_dim % num_heads != 0:
raise ValueError(
"embed_dim must be divisible by num_heads, but got {} and {}".format(embed_dim, num_heads)
)
head_dim = embed_dim // num_heads
if not _is_power_of_2(head_dim):
warnings.warn(
"""
You'd better set embed_dim in MSDeformAttn to make sure that
each dim of the attention head a power of 2, which is more efficient.
"""
)
self.im2col_step = img2col_step
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_levels = num_levels
self.num_points = num_points
# num_heads * num_points and num_levels for multi-level feature inputs
self.sampling_offsets = nn.Linear(embed_dim, num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dim, num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.output_proj = nn.Linear(embed_dim, embed_dim)
self.init_weights()
def init_weights(self):
"""Default initialization for parameters of the module"""
constant_(self.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(self.num_heads, dtype=torch.float32)
thetas = thetas * (2.0 * math.pi / self.num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = grid_init / grid_init.abs().max(-1, keepdim=True)[0]
grid_init = grid_init.view(self.num_heads, 1, 1, 2)
grid_init = grid_init.repeat(1, self.num_levels, self.num_points, 1)
for i in range(self.num_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.0)
constant_(self.attention_weights.bias.data, 0.0)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.0)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.0)
def forward(
self,
query: Tensor,
reference_points: Tensor,
value: Tensor,
spatial_shapes: Tensor,
level_start_index: Tensor,
key_padding_mask: Tensor,
) -> Tensor:
"""Forward function of MultiScaleDeformableAttention
:param query: query embeddings with shape (batch_size, num_query, embed_dim)
:param reference_points: the normalized reference points with shape
(batch_size, num_query, num_levels, 2); all elements are in the range [0, 1],
top-left (0, 0), bottom-right (1, 1), including the padding area. Alternatively,
(batch_size, num_query, num_levels, 4) with two additional dimensions (w, h)
to form reference boxes
:param value: value embeddings with shape (batch_size, num_value, embed_dim)
:param spatial_shapes: spatial shapes of features in different levels.
with shape (num_levels, 2), last dimension represents (h, w)
:param level_start_index: the start index of each level. A tensor with shape
(num_levels,), which can be represented as [0, h_0 * w_0, h_0 * w_0 + h_1 * w_1, ...]
:param key_padding_mask: ByteTensor/BoolTensor marking the padded positions of ``value``, with shape (batch_size, num_value)
:return: forward results with shape (batch_size, num_query, embed_dim)
"""
batch_size, num_query, _ = query.shape
batch_size, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
# value projection
value = self.value_proj(value)
# fill "0" for the padding part
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], float(0))
value = value.view(batch_size, num_value, self.num_heads, self.embed_dim // self.num_heads)
sampling_offsets = self.sampling_offsets(query).view(
batch_size, num_query, self.num_heads, self.num_levels, self.num_points, 2
)
# total num_levels * num_points features
attention_weights = self.attention_weights(query).view(
batch_size, num_query, self.num_heads, self.num_levels * self.num_points
)
attention_weights = attention_weights.softmax(-1)
attention_weights = attention_weights.view(
batch_size,
num_query,
self.num_heads,
self.num_levels,
self.num_points,
)
# batch_size, num_query, num_heads, num_levels, num_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = (
reference_points[:, :, None, :, None, :] +
sampling_offsets / offset_normalizer[None, None, None, :, None, :]
)
elif reference_points.shape[-1] == 4:
sampling_locations = (
reference_points[:, :, None, :, None, :2] +
sampling_offsets / self.num_points * reference_points[:, :, None, :, None, 2:] * 0.5
)
else:
raise ValueError(
"Last dim of reference_points must be 2 or 4, but get {} instead.".format(
reference_points.shape[-1]
)
)
# the original impl for fp32 training
if _C is not None and value.is_cuda:
output = MultiScaleDeformableAttnFunction.apply(
value.to(torch.float32),
spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, sampling_locations, attention_weights
)
if value.dtype != torch.float32:
output = output.to(value.dtype)
output = self.output_proj(output)
return output
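# --- Hedged usage sketch (illustrative only) ---
# One forward pass over a two-level feature pyramid on CPU; without the CUDA
# extension the module falls back to the pure-PyTorch path above. All sizes are
# arbitrary example values.
if __name__ == "__main__":
    spatial_shapes = torch.tensor([[8, 8], [4, 4]], dtype=torch.long)
    level_start_index = torch.cat(
        [spatial_shapes.new_zeros(1), spatial_shapes.prod(1).cumsum(0)[:-1]]
    )
    num_value = int(spatial_shapes.prod(1).sum())           # 8*8 + 4*4 = 80
    attn = MultiScaleDeformableAttention(embed_dim=256, num_levels=2, num_heads=8, num_points=4)
    query = torch.randn(2, 10, 256)
    value = torch.randn(2, num_value, 256)
    reference_points = torch.rand(2, 10, 2, 2)              # (bs, num_query, num_levels, 2)
    key_padding_mask = torch.zeros(2, num_value, dtype=torch.bool)
    out = attn(query, reference_points, value, spatial_shapes, level_start_index, key_padding_mask)
    assert out.shape == (2, 10, 256)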
#include <vector>
#include "ms_deform_im2col_cuda.cuh"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
const int batch_n = im2col_step_;
auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
columns.data<scalar_t>());
}));
}
output = output.view({batch, num_query, num_heads*channels});
return output;
}
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto grad_value = at::zeros_like(value);
auto grad_sampling_loc = at::zeros_like(sampling_loc);
auto grad_attn_weight = at::zeros_like(attn_weight);
const int batch_n = im2col_step_;
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
}));
}
return {
grad_value, grad_sampling_loc, grad_attn_weight
};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ms_deform_attn_forward", &ms_deform_attn_cuda_forward, "ms_deform_attn_forward");
m.def("ms_deform_attn_backward", &ms_deform_attn_cuda_backward, "ms_deform_attn_backward");
}