# ------------------------------------------------------------------------
# H-DETR
# Copyright (c) 2022 Peking University & Microsoft Research Asia. All Rights Reserved.
# Licensed under the MIT-style license found in the LICENSE file in the root directory
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import argparse
import datetime
import json
import random
import time
import os
import torch
import datasets
import numpy as np
import util.misc as utils
import datasets.samplers as samplers
from pathlib import Path
from torch.utils.data import DataLoader
from datasets import build_dataset, get_coco_api_from_dataset
from engine import evaluate, train_one_epoch
from models import build_model
def get_args_parser():
parser = argparse.ArgumentParser("Deformable DETR Detector", add_help=False)
parser.add_argument("--lr", default=2e-4, type=float)
parser.add_argument(
"--lr_backbone_names", default=["backbone.0"], type=str, nargs="+"
)
parser.add_argument("--lr_backbone", default=2e-5, type=float)
parser.add_argument(
"--lr_linear_proj_names",
default=["reference_points", "sampling_offsets"],
type=str,
nargs="+",
)
parser.add_argument("--lr_linear_proj_mult", default=0.1, type=float)
parser.add_argument("--batch_size", default=4, type=int)
parser.add_argument("--weight_decay", default=1e-4, type=float)
parser.add_argument("--epochs", default=50, type=int)
parser.add_argument("--lr_drop", default=40, type=int)
parser.add_argument("--lr_drop_epochs", default=None, type=int, nargs="+")
parser.add_argument(
"--clip_max_norm", default=0.1, type=float, help="gradient clipping max norm"
)
parser.add_argument("--sgd", action="store_true")
# Variants of Deformable DETR
parser.add_argument("--with_box_refine", default=False, action="store_true")
parser.add_argument("--two_stage", default=False, action="store_true")
# Model parameters
parser.add_argument(
"--frozen_weights",
type=str,
default=None,
help="Path to the pretrained model. If set, only the mask head will be trained",
)
# * Backbone
parser.add_argument(
"--backbone",
default="resnet50",
type=str,
help="Name of the convolutional backbone to use",
)
parser.add_argument(
"--dilation",
action="store_true",
help="If true, we replace stride with dilation in the last convolutional block (DC5)",
)
parser.add_argument(
"--position_embedding",
default="sine",
type=str,
choices=("sine", "learned"),
help="Type of positional embedding to use on top of the image features",
)
parser.add_argument(
"--position_embedding_scale",
default=2 * np.pi,
type=float,
help="position / size * scale",
)
parser.add_argument(
"--num_feature_levels", default=4, type=int, help="number of feature levels"
)
# swin backbone
parser.add_argument(
"--pretrained_backbone_path",
default="./swin_tiny_patch4_window7_224.pkl",
type=str,
)
parser.add_argument("--drop_path_rate", default=0.2, type=float)
# * Transformer
parser.add_argument(
"--enc_layers",
default=6,
type=int,
help="Number of encoding layers in the transformer",
)
parser.add_argument(
"--dec_layers",
default=6,
type=int,
help="Number of decoding layers in the transformer",
)
parser.add_argument(
"--dim_feedforward",
default=2048,
type=int,
help="Intermediate size of the feedforward layers in the transformer blocks",
)
parser.add_argument(
"--hidden_dim",
default=256,
type=int,
help="Size of the embeddings (dimension of the transformer)",
)
parser.add_argument(
"--dropout", default=0.1, type=float, help="Dropout applied in the transformer"
)
parser.add_argument(
"--nheads",
default=8,
type=int,
help="Number of attention heads inside the transformer's attentions",
)
parser.add_argument(
"--num_queries_one2one",
default=300,
type=int,
help="Number of query slots for one-to-one matching",
)
parser.add_argument(
"--num_queries_one2many",
default=0,
type=int,
help="Number of query slots for one-to-many matchining",
)
parser.add_argument("--dec_n_points", default=4, type=int)
parser.add_argument("--enc_n_points", default=4, type=int)
# Deformable DETR tricks
parser.add_argument("--mixed_selection", action="store_true", default=False)
parser.add_argument("--look_forward_twice", action="store_true", default=False)
# hybrid branch
parser.add_argument("--k_one2many", default=5, type=int)
parser.add_argument("--lambda_one2many", default=1.0, type=float)
# * Segmentation
parser.add_argument(
"--masks",
action="store_true",
help="Train segmentation head if the flag is provided",
)
# Loss
parser.add_argument(
"--no_aux_loss",
dest="aux_loss",
action="store_false",
help="Disables auxiliary decoding losses (loss at each layer)",
)
# * Matcher
parser.add_argument(
"--set_cost_class",
default=2,
type=float,
help="Class coefficient in the matching cost",
)
parser.add_argument(
"--set_cost_bbox",
default=5,
type=float,
help="L1 box coefficient in the matching cost",
)
parser.add_argument(
"--set_cost_giou",
default=2,
type=float,
help="giou box coefficient in the matching cost",
)
# * Loss coefficients
parser.add_argument("--mask_loss_coef", default=1, type=float)
parser.add_argument("--dice_loss_coef", default=1, type=float)
parser.add_argument("--cls_loss_coef", default=2, type=float)
parser.add_argument("--bbox_loss_coef", default=5, type=float)
parser.add_argument("--giou_loss_coef", default=2, type=float)
parser.add_argument("--focal_alpha", default=0.25, type=float)
# dataset parameters
parser.add_argument("--dataset_file", default="coco")
parser.add_argument("--coco_path", default="/public/DL_DATA/COCO2017", type=str)
parser.add_argument("--coco_panoptic_path", type=str)
parser.add_argument("--remove_difficult", action="store_true")
parser.add_argument(
"--output_dir", default="", help="path where to save, empty for no saving"
)
parser.add_argument(
"--device", default="cuda", help="device to use for training / testing"
)
parser.add_argument("--seed", default=42, type=int)
parser.add_argument("--resume", default="", help="resume from checkpoint")
parser.add_argument(
"--start_epoch", default=0, type=int, metavar="N", help="start epoch"
)
parser.add_argument("--num_workers", default=2, type=int)
parser.add_argument(
"--cache_mode",
default=False,
action="store_true",
help="whether to cache images on memory",
)
# * evaluation options
parser.add_argument("--eval", action="store_true")
# eval in training set
parser.add_argument("--eval_in_training_set", default=False, action="store_true")
# topk for eval
parser.add_argument("--topk", default=100, type=int)
# * training options
parser.add_argument("--use_fp16", default=False, action="store_true")
parser.add_argument("--use_checkpoint", default=False, action="store_true")
# * logging options
parser.add_argument("--use_wandb", action="store_true", default=False)
return parser
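# Usage sketch (illustrative, not part of the original script): the hybrid
# matching branch is enabled by giving the one-to-many group its own queries:
#   python main.py --coco_path /path/to/coco --batch_size 2 \
#       --num_queries_one2one 300 --num_queries_one2many 1500 \
#       --k_one2many 6 --lambda_one2many 1.0
# With --num_queries_one2many 0 (the default), training reduces to plain
# one-to-one Deformable DETR matching.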
def main(args):
utils.init_distributed_mode(args)
print("git:\n {}\n".format(utils.get_sha()))
if args.frozen_weights is not None:
assert args.masks, "Frozen training is meant for segmentation only"
print(args)
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
model, criterion, postprocessors = build_model(args)
model.to(device)
model_without_ddp = model
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of params:", n_parameters)
dataset_train = build_dataset(image_set="train", args=args)
if not args.eval_in_training_set:
dataset_val = build_dataset(
image_set="val", args=args, eval_in_training_set=False,
)
else:
print("eval in the training set")
dataset_val = build_dataset(
image_set="train", args=args, eval_in_training_set=True,
)
if args.distributed:
if args.cache_mode:
sampler_train = samplers.NodeDistributedSampler(dataset_train)
sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False)
else:
sampler_train = samplers.DistributedSampler(dataset_train)
sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False)
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
batch_sampler_train = torch.utils.data.BatchSampler(
sampler_train, args.batch_size, drop_last=True
)
data_loader_train = DataLoader(
dataset_train,
batch_sampler=batch_sampler_train,
collate_fn=utils.collate_fn,
num_workers=args.num_workers,
pin_memory=True,
)
data_loader_val = DataLoader(
dataset_val,
args.batch_size,
sampler=sampler_val,
drop_last=False,
collate_fn=utils.collate_fn,
num_workers=args.num_workers,
pin_memory=True,
)
# lr_backbone_names = ["backbone.0", "backbone.neck", "input_proj", "transformer.encoder"]
def match_name_keywords(n, name_keywords):
out = False
for b in name_keywords:
if b in n:
out = True
break
return out
for n, p in model_without_ddp.named_parameters():
print(n)
param_dicts = [
{
"params": [
p
for n, p in model_without_ddp.named_parameters()
if not match_name_keywords(n, args.lr_backbone_names)
and not match_name_keywords(n, args.lr_linear_proj_names)
and p.requires_grad
],
"lr": args.lr,
},
{
"params": [
p
for n, p in model_without_ddp.named_parameters()
if match_name_keywords(n, args.lr_backbone_names) and p.requires_grad
],
"lr": args.lr_backbone,
},
{
"params": [
p
for n, p in model_without_ddp.named_parameters()
if match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad
],
"lr": args.lr * args.lr_linear_proj_mult,
},
]
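    # Illustrative effect of the three groups above with the default flags
    # (parameter names are examples, not an exhaustive list):
    #   parameters matching "backbone.0"                 -> lr_backbone (2e-5)
    #   "reference_points" / "sampling_offsets" weights  -> lr * 0.1    (2e-5)
    #   every other trainable parameter                  -> lr          (2e-4)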
if args.sgd:
optimizer = torch.optim.SGD(
param_dicts, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay
)
else:
optimizer = torch.optim.AdamW(
param_dicts, lr=args.lr, weight_decay=args.weight_decay
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
if args.dataset_file == "coco_panoptic":
# We also evaluate AP during panoptic training, on the original COCO dataset
coco_val = datasets.coco.build("val", args)
base_ds = get_coco_api_from_dataset(coco_val)
else:
base_ds = get_coco_api_from_dataset(dataset_val)
if args.frozen_weights is not None:
checkpoint = torch.load(args.frozen_weights, map_location="cpu")
model_without_ddp.detr.load_state_dict(checkpoint["model"])
output_dir = Path(args.output_dir)
if args.resume and os.path.exists(args.resume):
if args.resume.startswith("https"):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location="cpu", check_hash=True
)
else:
checkpoint = torch.load(args.resume, map_location="cpu")
missing_keys, unexpected_keys = model_without_ddp.load_state_dict(
checkpoint["model"], strict=False
)
unexpected_keys = [
k
for k in unexpected_keys
if not (k.endswith("total_params") or k.endswith("total_ops"))
]
if len(missing_keys) > 0:
print("Missing Keys: {}".format(missing_keys))
if len(unexpected_keys) > 0:
print("Unexpected Keys: {}".format(unexpected_keys))
if (
not args.eval
and "optimizer" in checkpoint
and "lr_scheduler" in checkpoint
and "epoch" in checkpoint
):
import copy
p_groups = copy.deepcopy(optimizer.param_groups)
optimizer.load_state_dict(checkpoint["optimizer"])
for pg, pg_old in zip(optimizer.param_groups, p_groups):
pg["lr"] = pg_old["lr"]
pg["initial_lr"] = pg_old["initial_lr"]
print(optimizer.param_groups)
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
# todo: this is a hack for doing experiment that resume from checkpoint and also modify lr scheduler (e.g., decrease lr in advance).
args.override_resumed_lr_drop = True
if args.override_resumed_lr_drop:
print(
"Warning: (hack) args.override_resumed_lr_drop is set to True, so args.lr_drop would override lr_drop in resumed lr_scheduler."
)
lr_scheduler.step_size = args.lr_drop
lr_scheduler.base_lrs = list(
map(lambda group: group["initial_lr"], optimizer.param_groups)
)
lr_scheduler.step(lr_scheduler.last_epoch)
args.start_epoch = checkpoint["epoch"] + 1
# check the resumed model
if not args.eval:
test_stats, coco_evaluator = evaluate(
model,
criterion,
postprocessors,
data_loader_val,
base_ds,
device,
args.output_dir,
use_wandb=args.use_wandb,
)
if args.eval:
test_stats, coco_evaluator = evaluate(
model,
criterion,
postprocessors,
data_loader_val,
base_ds,
device,
args.output_dir,
use_wandb=args.use_wandb,
)
if args.output_dir:
utils.save_on_master(
coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth"
)
return
print("Start training")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
sampler_train.set_epoch(epoch)
train_stats = train_one_epoch(
model,
criterion,
data_loader_train,
optimizer,
device,
epoch,
args.clip_max_norm,
k_one2many=args.k_one2many,
lambda_one2many=args.lambda_one2many,
use_wandb=args.use_wandb,
use_fp16=args.use_fp16,
)
lr_scheduler.step()
if args.output_dir:
checkpoint_paths = [output_dir / "checkpoint.pth"]
# additionally save a per-epoch checkpoint
checkpoint_paths.append(output_dir / f"checkpoint{epoch:04}.pth")
for checkpoint_path in checkpoint_paths:
utils.save_on_master(
{
"model": model_without_ddp.state_dict(),
"optimizer": optimizer.state_dict(),
"lr_scheduler": lr_scheduler.state_dict(),
"epoch": epoch,
"args": args,
},
checkpoint_path,
)
test_stats, coco_evaluator = evaluate(
model,
criterion,
postprocessors,
data_loader_val,
base_ds,
device,
args.output_dir,
use_wandb=args.use_wandb,
)
log_stats = {
**{f"train_{k}": v for k, v in train_stats.items()},
**{f"test_{k}": v for k, v in test_stats.items()},
"epoch": epoch,
"n_parameters": n_parameters,
}
if args.output_dir and utils.is_main_process():
with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
# for evaluation logs
if coco_evaluator is not None:
(output_dir / "eval").mkdir(exist_ok=True)
if "bbox" in coco_evaluator.coco_eval:
filenames = ["latest.pth"]
if epoch % 50 == 0:
filenames.append(f"{epoch:03}.pth")
for name in filenames:
torch.save(
coco_evaluator.coco_eval["bbox"].eval,
output_dir / "eval" / name,
)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("Training time {}".format(total_time_str))
if __name__ == "__main__":
parser = argparse.ArgumentParser(
"Deformable DETR training and evaluation script", parents=[get_args_parser()]
)
args = parser.parse_args()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)
# -*- coding: utf-8 -*-
from .checkpoint import load_checkpoint
__all__ = ["load_checkpoint"]
# Copyright (c) Open-MMLab. All rights reserved.
import io
import os
import os.path as osp
import pkgutil
import time
import warnings
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory
import torch
import torchvision
from torch.optim import Optimizer
from torch.utils import model_zoo
from torch.nn import functional as F
import mmcv
from mmcv.fileio import FileClient
from mmcv.fileio import load as load_file
from mmcv.parallel import is_module_wrapper
from mmcv.utils import mkdir_or_exist
from mmcv.runner import get_dist_info
ENV_MMCV_HOME = "MMCV_HOME"
ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
DEFAULT_CACHE_DIR = "~/.cache"
def _get_mmcv_home():
mmcv_home = os.path.expanduser(
os.getenv(
ENV_MMCV_HOME,
os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "mmcv"),
)
)
mkdir_or_exist(mmcv_home)
return mmcv_home
def load_state_dict(module, state_dict, strict=False, logger=None):
"""Load state_dict to a module.
This method is modified from :meth:`torch.nn.Module.load_state_dict`.
Default value for ``strict`` is set to ``False`` and the message for
param mismatch will be shown even if strict is False.
Args:
module (Module): Module that receives the state_dict.
state_dict (OrderedDict): Weights.
strict (bool): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
logger (:obj:`logging.Logger`, optional): Logger to log the error
message. If not specified, print function will be used.
"""
unexpected_keys = []
all_missing_keys = []
err_msg = []
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
# use _load_from_state_dict to enable checkpoint version control
def load(module, prefix=""):
# recursively check parallel module in case that the model has a
# complicated structure, e.g., nn.Module(nn.Module(DDP))
if is_module_wrapper(module):
module = module.module
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict,
prefix,
local_metadata,
True,
all_missing_keys,
unexpected_keys,
err_msg,
)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
load(module)
load = None # break load->load reference cycle
# ignore "num_batches_tracked" of BN layers
missing_keys = [key for key in all_missing_keys if "num_batches_tracked" not in key]
if unexpected_keys:
err_msg.append(
"unexpected key in source " f'state_dict: {", ".join(unexpected_keys)}\n'
)
if missing_keys:
err_msg.append(
f'missing keys in source state_dict: {", ".join(missing_keys)}\n'
)
rank, _ = get_dist_info()
if len(err_msg) > 0 and rank == 0:
err_msg.insert(0, "The model and loaded state dict do not match exactly\n")
err_msg = "\n".join(err_msg)
if strict:
raise RuntimeError(err_msg)
elif logger is not None:
logger.warning(err_msg)
else:
print(err_msg)
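# Example (illustrative): with strict=False (the default), a partially
# matching checkpoint is loaded and the mismatch is only reported:
#   sd = torch.load("ckpt.pth", map_location="cpu")["state_dict"]
#   load_state_dict(model, sd, strict=False)  # prints missing/unexpected keys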
def load_url_dist(url, model_dir=None):
"""In distributed setting, this function only download checkpoint at local
rank 0."""
rank, world_size = get_dist_info()
rank = int(os.environ.get("LOCAL_RANK", rank))
if rank == 0:
checkpoint = model_zoo.load_url(url, model_dir=model_dir)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
checkpoint = model_zoo.load_url(url, model_dir=model_dir)
return checkpoint
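# Note: rank 0 downloads first and thereby populates the local cache; after
# the barrier, the remaining ranks call load_url again and hit that cache
# instead of re-downloading.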
def load_pavimodel_dist(model_path, map_location=None):
"""In distributed setting, this function only download checkpoint at local
rank 0."""
try:
from pavi import modelcloud
except ImportError:
raise ImportError("Please install pavi to load checkpoint from modelcloud.")
rank, world_size = get_dist_info()
rank = int(os.environ.get("LOCAL_RANK", rank))
if rank == 0:
model = modelcloud.get(model_path)
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
checkpoint = torch.load(downloaded_file, map_location=map_location)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
model = modelcloud.get(model_path)
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
checkpoint = torch.load(downloaded_file, map_location=map_location)
return checkpoint
def load_fileclient_dist(filename, backend, map_location):
"""In distributed setting, this function only download checkpoint at local
rank 0."""
rank, world_size = get_dist_info()
rank = int(os.environ.get("LOCAL_RANK", rank))
allowed_backends = ["ceph"]
if backend not in allowed_backends:
raise ValueError(f"Load from Backend {backend} is not supported.")
if rank == 0:
fileclient = FileClient(backend=backend)
buffer = io.BytesIO(fileclient.get(filename))
checkpoint = torch.load(buffer, map_location=map_location)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
fileclient = FileClient(backend=backend)
buffer = io.BytesIO(fileclient.get(filename))
checkpoint = torch.load(buffer, map_location=map_location)
return checkpoint
def get_torchvision_models():
model_urls = dict()
for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
if ispkg:
continue
_zoo = import_module(f"torchvision.models.{name}")
if hasattr(_zoo, "model_urls"):
_urls = getattr(_zoo, "model_urls")
model_urls.update(_urls)
return model_urls
def get_external_models():
mmcv_home = _get_mmcv_home()
default_json_path = osp.join(mmcv.__path__[0], "model_zoo/open_mmlab.json")
default_urls = load_file(default_json_path)
assert isinstance(default_urls, dict)
external_json_path = osp.join(mmcv_home, "open_mmlab.json")
if osp.exists(external_json_path):
external_urls = load_file(external_json_path)
assert isinstance(external_urls, dict)
default_urls.update(external_urls)
return default_urls
def get_mmcls_models():
mmcls_json_path = osp.join(mmcv.__path__[0], "model_zoo/mmcls.json")
mmcls_urls = load_file(mmcls_json_path)
return mmcls_urls
def get_deprecated_model_names():
deprecate_json_path = osp.join(mmcv.__path__[0], "model_zoo/deprecated.json")
deprecate_urls = load_file(deprecate_json_path)
assert isinstance(deprecate_urls, dict)
return deprecate_urls
def _process_mmcls_checkpoint(checkpoint):
state_dict = checkpoint["state_dict"]
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith("backbone."):
new_state_dict[k[9:]] = v
new_checkpoint = dict(state_dict=new_state_dict)
return new_checkpoint
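# Illustrative effect: an mmcls key such as "backbone.layer1.0.conv1.weight"
# becomes "layer1.0.conv1.weight"; keys without the "backbone." prefix are
# dropped.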
def _load_checkpoint(filename, map_location=None):
"""Load checkpoint from somewhere (modelzoo, file, url).
Args:
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str | None): Same as :func:`torch.load`. Default: None.
Returns:
dict | OrderedDict: The loaded checkpoint. It can be either an
OrderedDict storing model weights or a dict containing other
information, which depends on the checkpoint.
"""
if filename.startswith("modelzoo://"):
warnings.warn(
'The URL scheme of "modelzoo://" is deprecated, please '
'use "torchvision://" instead'
)
model_urls = get_torchvision_models()
model_name = filename[11:]
checkpoint = load_url_dist(model_urls[model_name])
elif filename.startswith("torchvision://"):
model_urls = get_torchvision_models()
model_name = filename[14:]
checkpoint = load_url_dist(model_urls[model_name])
elif filename.startswith("open-mmlab://"):
model_urls = get_external_models()
model_name = filename[13:]
deprecated_urls = get_deprecated_model_names()
if model_name in deprecated_urls:
warnings.warn(
f"open-mmlab://{model_name} is deprecated in favor "
f"of open-mmlab://{deprecated_urls[model_name]}"
)
model_name = deprecated_urls[model_name]
model_url = model_urls[model_name]
# check if is url
if model_url.startswith(("http://", "https://")):
checkpoint = load_url_dist(model_url)
else:
filename = osp.join(_get_mmcv_home(), model_url)
if not osp.isfile(filename):
raise IOError(f"{filename} is not a checkpoint file")
checkpoint = torch.load(filename, map_location=map_location)
elif filename.startswith("mmcls://"):
model_urls = get_mmcls_models()
model_name = filename[8:]
checkpoint = load_url_dist(model_urls[model_name])
checkpoint = _process_mmcls_checkpoint(checkpoint)
elif filename.startswith(("http://", "https://")):
checkpoint = load_url_dist(filename)
elif filename.startswith("pavi://"):
model_path = filename[7:]
checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
elif filename.startswith("s3://"):
checkpoint = load_fileclient_dist(
filename, backend="ceph", map_location=map_location
)
else:
if not osp.isfile(filename):
raise IOError(f"{filename} is not a checkpoint file")
checkpoint = torch.load(filename, map_location=map_location)
return checkpoint
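# Examples (illustrative) of the URI schemes dispatched above:
#   _load_checkpoint("torchvision://resnet50")
#   _load_checkpoint("open-mmlab://some_model_name")   # name is illustrative
#   _load_checkpoint("https://example.com/ckpt.pth")
#   _load_checkpoint("/local/path/ckpt.pth", map_location="cpu")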
def load_checkpoint(model, filename, map_location="cpu", strict=False, logger=None):
"""Load checkpoint from a file or URI.
Args:
model (Module): Module to load checkpoint.
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
checkpoint = _load_checkpoint(filename, map_location)
# OrderedDict is a subclass of dict
if not isinstance(checkpoint, dict):
raise RuntimeError(f"No state_dict found in checkpoint file {filename}")
# get state_dict from checkpoint
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
elif "model" in checkpoint:
state_dict = checkpoint["model"]
else:
state_dict = checkpoint
# strip prefix of state_dict
if list(state_dict.keys())[0].startswith("module."):
state_dict = {k[7:]: v for k, v in state_dict.items()}
# for MoBY, load model of online branch
if sorted(list(state_dict.keys()))[0].startswith("encoder"):
state_dict = {
k.replace("encoder.", ""): v
for k, v in state_dict.items()
if k.startswith("encoder.")
}
# reshape absolute position embedding
if state_dict.get("absolute_pos_embed") is not None:
absolute_pos_embed = state_dict["absolute_pos_embed"]
N1, L, C1 = absolute_pos_embed.size()
N2, C2, H, W = model.absolute_pos_embed.size()
if N1 != N2 or C1 != C2 or L != H * W:
logger.warning("Error in loading absolute_pos_embed, pass")
else:
state_dict["absolute_pos_embed"] = absolute_pos_embed.view(
N2, H, W, C2
).permute(0, 3, 1, 2)
# interpolate position bias table if needed
relative_position_bias_table_keys = [
k for k in state_dict.keys() if "relative_position_bias_table" in k
]
for table_key in relative_position_bias_table_keys:
table_pretrained = state_dict[table_key]
table_current = model.state_dict()[table_key]
L1, nH1 = table_pretrained.size()
L2, nH2 = table_current.size()
if nH1 != nH2:
logger.warning(f"Error in loading {table_key}, pass")
else:
if L1 != L2:
S1 = int(L1 ** 0.5)
S2 = int(L2 ** 0.5)
table_pretrained_resized = F.interpolate(
table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
size=(S2, S2),
mode="bicubic",
)
state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(
1, 0
)
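    # Worked example (illustrative): a pretrained window size of 7 gives a
    # table of length L1 = (2*7 - 1)**2 = 169, i.e. a 13x13 grid; loading
    # into a model with window size 12 needs L2 = (2*12 - 1)**2 = 529, so the
    # 13x13 grid is bicubic-resized to 23x23 and flattened back to length 529.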
# load state_dict
load_state_dict(model, state_dict, strict, logger)
return checkpoint
def weights_to_cpu(state_dict):
"""Copy a model state_dict to cpu.
Args:
state_dict (OrderedDict): Model weights on GPU.
Returns:
OrderedDict: Model weights on CPU.
"""
state_dict_cpu = OrderedDict()
for key, val in state_dict.items():
state_dict_cpu[key] = val.cpu()
return state_dict_cpu
def _save_to_state_dict(module, destination, prefix, keep_vars):
"""Saves module state to `destination` dictionary.
This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
Args:
module (nn.Module): The module to generate state_dict.
destination (dict): A dict where state will be stored.
prefix (str): The prefix for parameters and buffers used in this
module.
keep_vars (bool): Whether to keep the variable property of the
parameters. Default: False.
"""
for name, param in module._parameters.items():
if param is not None:
destination[prefix + name] = param if keep_vars else param.detach()
for name, buf in module._buffers.items():
# remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
if buf is not None:
destination[prefix + name] = buf if keep_vars else buf.detach()
def get_state_dict(module, destination=None, prefix="", keep_vars=False):
"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
This method is modified from :meth:`torch.nn.Module.state_dict` to
recursively check parallel module in case that the model has a complicated
structure, e.g., nn.Module(nn.Module(DDP)).
Args:
module (nn.Module): The module to generate state_dict.
destination (OrderedDict): Returned dict for the state of the
module.
prefix (str): Prefix of the key.
keep_vars (bool): Whether to keep the variable property of the
parameters. Default: False.
Returns:
dict: A dictionary containing a whole state of the module.
"""
# recursively check parallel module in case that the model has a
# complicated structure, e.g., nn.Module(nn.Module(DDP))
if is_module_wrapper(module):
module = module.module
# below is the same as torch.nn.Module.state_dict()
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
destination._metadata[prefix[:-1]] = local_metadata = dict(version=module._version)
_save_to_state_dict(module, destination, prefix, keep_vars)
for name, child in module._modules.items():
if child is not None:
get_state_dict(child, destination, prefix + name + ".", keep_vars=keep_vars)
for hook in module._state_dict_hooks.values():
hook_result = hook(module, destination, prefix, local_metadata)
if hook_result is not None:
destination = hook_result
return destination
def save_checkpoint(model, filename, optimizer=None, meta=None):
"""Save checkpoint to file.
The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
``optimizer``. By default ``meta`` will contain version and time info.
Args:
model (Module): Module whose params are to be saved.
filename (str): Checkpoint filename.
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
meta (dict, optional): Metadata to be saved in checkpoint.
"""
if meta is None:
meta = {}
elif not isinstance(meta, dict):
raise TypeError(f"meta must be a dict or None, but got {type(meta)}")
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
if is_module_wrapper(model):
model = model.module
if hasattr(model, "CLASSES") and model.CLASSES is not None:
# save class name to the meta
meta.update(CLASSES=model.CLASSES)
checkpoint = {"meta": meta, "state_dict": weights_to_cpu(get_state_dict(model))}
# save optimizer state dict in the checkpoint
if isinstance(optimizer, Optimizer):
checkpoint["optimizer"] = optimizer.state_dict()
elif isinstance(optimizer, dict):
checkpoint["optimizer"] = {}
for name, optim in optimizer.items():
checkpoint["optimizer"][name] = optim.state_dict()
if filename.startswith("pavi://"):
try:
from pavi import modelcloud
from pavi.exception import NodeNotFoundError
except ImportError:
raise ImportError("Please install pavi to load checkpoint from modelcloud.")
model_path = filename[7:]
root = modelcloud.Folder()
model_dir, model_name = osp.split(model_path)
try:
model = modelcloud.get(model_dir)
except NodeNotFoundError:
model = root.create_training_model(model_dir)
with TemporaryDirectory() as tmp_dir:
checkpoint_file = osp.join(tmp_dir, model_name)
with open(checkpoint_file, "wb") as f:
torch.save(checkpoint, f)
f.flush()
model.create_file(checkpoint_file, name=model_name)
else:
mmcv.mkdir_or_exist(osp.dirname(filename))
# immediately flush buffer
with open(filename, "wb") as f:
torch.save(checkpoint, f)
f.flush()
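# Example (illustrative):
#   save_checkpoint(model, "work_dir/epoch_12.pth", optimizer=optimizer,
#                   meta=dict(epoch=12, iter=48000))
# writes a dict with "meta" (mmcv version, time, CLASSES if present),
# "state_dict" (weights moved to CPU) and "optimizer".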
# Copyright (c) Open-MMLab. All rights reserved.
from .checkpoint import save_checkpoint
from .epoch_based_runner import EpochBasedRunnerAmp
__all__ = ["EpochBasedRunnerAmp", "save_checkpoint"]
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import time
from tempfile import TemporaryDirectory
import torch
from torch.optim import Optimizer
import mmcv
from mmcv.parallel import is_module_wrapper
from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict
try:
import apex
except ImportError:
print("apex is not installed")
def save_checkpoint(model, filename, optimizer=None, meta=None):
"""Save checkpoint to file.
The checkpoint will have 4 fields: ``meta``, ``state_dict``,
``optimizer`` and ``amp``. By default ``meta`` will contain version
and time info.
Args:
model (Module): Module whose params are to be saved.
filename (str): Checkpoint filename.
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
meta (dict, optional): Metadata to be saved in checkpoint.
"""
if meta is None:
meta = {}
elif not isinstance(meta, dict):
raise TypeError(f"meta must be a dict or None, but got {type(meta)}")
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
if is_module_wrapper(model):
model = model.module
if hasattr(model, "CLASSES") and model.CLASSES is not None:
# save class name to the meta
meta.update(CLASSES=model.CLASSES)
checkpoint = {"meta": meta, "state_dict": weights_to_cpu(get_state_dict(model))}
# save optimizer state dict in the checkpoint
if isinstance(optimizer, Optimizer):
checkpoint["optimizer"] = optimizer.state_dict()
elif isinstance(optimizer, dict):
checkpoint["optimizer"] = {}
for name, optim in optimizer.items():
checkpoint["optimizer"][name] = optim.state_dict()
# save amp state dict in the checkpoint
checkpoint["amp"] = apex.amp.state_dict()
if filename.startswith("pavi://"):
try:
from pavi import modelcloud
from pavi.exception import NodeNotFoundError
except ImportError:
raise ImportError("Please install pavi to load checkpoint from modelcloud.")
model_path = filename[7:]
root = modelcloud.Folder()
model_dir, model_name = osp.split(model_path)
try:
model = modelcloud.get(model_dir)
except NodeNotFoundError:
model = root.create_training_model(model_dir)
with TemporaryDirectory() as tmp_dir:
checkpoint_file = osp.join(tmp_dir, model_name)
with open(checkpoint_file, "wb") as f:
torch.save(checkpoint, f)
f.flush()
model.create_file(checkpoint_file, name=model_name)
else:
mmcv.mkdir_or_exist(osp.dirname(filename))
# immediately flush buffer
with open(filename, "wb") as f:
torch.save(checkpoint, f)
f.flush()
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import platform
import shutil
import torch
from torch.optim import Optimizer
import mmcv
from mmcv.runner import RUNNERS, EpochBasedRunner
from .checkpoint import save_checkpoint
try:
import apex
except ImportError:
print("apex is not installed")
@RUNNERS.register_module()
class EpochBasedRunnerAmp(EpochBasedRunner):
"""Epoch-based Runner with AMP support.
This runner trains models epoch by epoch.
"""
def save_checkpoint(
self,
out_dir,
filename_tmpl="epoch_{}.pth",
save_optimizer=True,
meta=None,
create_symlink=True,
):
"""Save the checkpoint.
Args:
out_dir (str): The directory that checkpoints are saved.
filename_tmpl (str, optional): The checkpoint filename template,
which contains a placeholder for the epoch number.
Defaults to 'epoch_{}.pth'.
save_optimizer (bool, optional): Whether to save the optimizer to
the checkpoint. Defaults to True.
meta (dict, optional): The meta information to be saved in the
checkpoint. Defaults to None.
create_symlink (bool, optional): Whether to create a symlink
"latest.pth" to point to the latest checkpoint.
Defaults to True.
"""
if meta is None:
meta = dict(epoch=self.epoch + 1, iter=self.iter)
elif isinstance(meta, dict):
meta.update(epoch=self.epoch + 1, iter=self.iter)
else:
raise TypeError(f"meta should be a dict or None, but got {type(meta)}")
if self.meta is not None:
meta.update(self.meta)
filename = filename_tmpl.format(self.epoch + 1)
filepath = osp.join(out_dir, filename)
optimizer = self.optimizer if save_optimizer else None
save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
# in some environments, `os.symlink` is not supported, you may need to
# set `create_symlink` to False
if create_symlink:
dst_file = osp.join(out_dir, "latest.pth")
if platform.system() != "Windows":
mmcv.symlink(filename, dst_file)
else:
shutil.copy(filepath, dst_file)
def resume(self, checkpoint, resume_optimizer=True, map_location="default"):
if map_location == "default":
if torch.cuda.is_available():
device_id = torch.cuda.current_device()
checkpoint = self.load_checkpoint(
checkpoint,
map_location=lambda storage, loc: storage.cuda(device_id),
)
else:
checkpoint = self.load_checkpoint(checkpoint)
else:
checkpoint = self.load_checkpoint(checkpoint, map_location=map_location)
self._epoch = checkpoint["meta"]["epoch"]
self._iter = checkpoint["meta"]["iter"]
if "optimizer" in checkpoint and resume_optimizer:
if isinstance(self.optimizer, Optimizer):
self.optimizer.load_state_dict(checkpoint["optimizer"])
elif isinstance(self.optimizer, dict):
for k in self.optimizer.keys():
self.optimizer[k].load_state_dict(checkpoint["optimizer"][k])
else:
raise TypeError(
"Optimizer should be dict or torch.optim.Optimizer "
f"but got {type(self.optimizer)}"
)
if "amp" in checkpoint:
apex.amp.load_state_dict(checkpoint["amp"])
self.logger.info("load amp state dict")
self.logger.info("resumed epoch %d, iter %d", self.epoch, self.iter)
# Unique model identifier
modelCode=475
# Model name
modelName=hdetr_pytorch
# Model description
modelDescription=HDETR introduces a hybrid matching scheme: the new matching mechanism allows multiple queries to be assigned to each positive sample, which improves training and makes the model applicable to a variety of vision tasks such as object detection, 3D object detection, pose estimation, and object tracking.
# Application scenarios
appScenario=inference,training,object detection,network security,transportation,government
# Framework type
frameType=PyTorch
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
from .deformable_detr import build, build_test
def build_model(args):
return build(args)
def build_test_model(args):
return build_test(args)
# ------------------------------------------------------------------------
# H-DETR
# Copyright (c) 2022 Peking University & Microsoft Research Asia. All Rights Reserved.
# Licensed under the MIT-style license found in the LICENSE file in the root directory
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
"""
Backbone modules.
"""
from collections import OrderedDict
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List
from util.misc import NestedTensor, is_main_process
from .position_encoding import build_position_encoding
from .swin_transformer import SwinTransformer
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.
Copy-paste from torchvision.misc.ops with added eps before rsqrt,
without which models other than torchvision.models.resnet[18,34,50,101]
produce NaNs.
"""
def __init__(self, n, eps=1e-5):
super(FrozenBatchNorm2d, self).__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
self.eps = eps
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = self.eps
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
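        # This is algebraically identical to the standard BN inference
        # transform y = (x - running_mean) / sqrt(running_var + eps) * weight
        # + bias, folded into a single multiply-add that fusers handle well.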
class BackboneBase(nn.Module):
def __init__(
self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool
):
super().__init__()
for name, parameter in backbone.named_parameters():
if (
not train_backbone
or "layer2" not in name
and "layer3" not in name
and "layer4" not in name
):
parameter.requires_grad_(False)
if return_interm_layers:
# return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
self.strides = [8, 16, 32]
self.num_channels = [512, 1024, 2048]
else:
return_layers = {"layer4": "0"}
self.strides = [32]
self.num_channels = [2048]
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
def forward(self, tensor_list: NestedTensor):
xs = self.body(tensor_list.tensors)
out: Dict[str, NestedTensor] = {}
for name, x in xs.items():
m = tensor_list.mask
assert m is not None
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
out[name] = NestedTensor(x, mask)
return out
class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(
self,
name: str,
train_backbone: bool,
return_interm_layers: bool,
dilation: bool,
):
norm_layer = FrozenBatchNorm2d
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(),
norm_layer=norm_layer,
)
assert name not in ("resnet18", "resnet34"), "number of channels are hard coded"
super().__init__(backbone, train_backbone, return_interm_layers)
if dilation:
self.strides[-1] = self.strides[-1] // 2
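    # Note: with dilation (DC5), layer4 keeps stride 16 instead of 32, so its
    # feature map is twice as large in each dimension; hence strides[-1] // 2.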
class TransformerBackbone(nn.Module):
def __init__(
self, backbone: str, train_backbone: bool, return_interm_layers: bool, args
):
super().__init__()
out_indices = (1, 2, 3)
if backbone == "swin_tiny":
backbone = SwinTransformer(
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
ape=False,
drop_path_rate=args.drop_path_rate,
patch_norm=True,
use_checkpoint=True,
out_indices=out_indices,
)
embed_dim = 96
backbone.init_weights(args.pretrained_backbone_path)
elif backbone == "swin_small":
backbone = SwinTransformer(
embed_dim=96,
depths=[2, 2, 18, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
ape=False,
drop_path_rate=args.drop_path_rate,
patch_norm=True,
use_checkpoint=True,
out_indices=out_indices,
)
embed_dim = 96
backbone.init_weights(args.pretrained_backbone_path)
elif backbone == "swin_large":
backbone = SwinTransformer(
embed_dim=192,
depths=[2, 2, 18, 2],
num_heads=[6, 12, 24, 48],
window_size=7,
ape=False,
drop_path_rate=args.drop_path_rate,
patch_norm=True,
use_checkpoint=True,
out_indices=out_indices,
)
embed_dim = 192
backbone.init_weights(args.pretrained_backbone_path)
elif backbone == "swin_large_window12":
backbone = SwinTransformer(
pretrain_img_size=384,
embed_dim=192,
depths=[2, 2, 18, 2],
num_heads=[6, 12, 24, 48],
window_size=12,
ape=False,
drop_path_rate=args.drop_path_rate,
patch_norm=True,
use_checkpoint=True,
out_indices=out_indices,
)
embed_dim = 192
backbone.init_weights(args.pretrained_backbone_path)
else:
raise NotImplementedError
for name, parameter in backbone.named_parameters():
# TODO: freeze some layers?
if not train_backbone:
parameter.requires_grad_(False)
if return_interm_layers:
self.strides = [8, 16, 32]
self.num_channels = [
embed_dim * 2,
embed_dim * 4,
embed_dim * 8,
]
else:
self.strides = [32]
self.num_channels = [embed_dim * 8]
self.body = backbone
def forward(self, tensor_list: NestedTensor):
xs = self.body(tensor_list.tensors)
out: Dict[str, NestedTensor] = {}
for name, x in xs.items():
m = tensor_list.mask
assert m is not None
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
out[name] = NestedTensor(x, mask)
return out
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
self.strides = backbone.strides
self.num_channels = backbone.num_channels
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for name, x in sorted(xs.items()):
out.append(x)
# position encoding
for x in out:
pos.append(self[1](x).to(x.tensors.dtype))
return out, pos
def build_backbone(args):
position_embedding = build_position_encoding(args)
train_backbone = args.lr_backbone > 0
return_interm_layers = args.masks or (args.num_feature_levels > 1)
if "resnet" in args.backbone:
backbone = Backbone(
args.backbone, train_backbone, return_interm_layers, args.dilation,
)
else:
backbone = TransformerBackbone(
args.backbone, train_backbone, return_interm_layers, args
)
model = Joiner(backbone, position_embedding)
return model
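# Usage sketch (illustrative): any args namespace with the fields read above
# works, e.g. the defaults from main.py's get_args_parser():
#   backbone = build_backbone(args)          # Joiner(backbone, pos_encoding)
#   features, pos = backbone(nested_tensor)  # per-level NestedTensors + pos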
# ------------------------------------------------------------------------
# H-DETR
# Copyright (c) 2022 Peking University & Microsoft Research Asia. All Rights Reserved.
# Licensed under the MIT-style license found in the LICENSE file in the root directory
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
"""
Deformable DETR model and criterion classes.
"""
import torch
import torch.nn.functional as F
from torch import nn
import math
from util import box_ops
from util.misc import (
NestedTensor,
nested_tensor_from_tensor_list,
accuracy,
get_world_size,
interpolate,
is_dist_avail_and_initialized,
inverse_sigmoid,
)
from .backbone import build_backbone
from .matcher import build_matcher
from .segmentation import (
DETRsegm,
PostProcessPanoptic,
PostProcessSegm,
dice_loss,
sigmoid_focal_loss,
)
from .deformable_transformer import build_deforamble_transformer
import copy
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
class DeformableDETR(nn.Module):
""" This is the Deformable DETR module that performs object detection """
def __init__(
self,
backbone,
transformer,
num_classes,
num_feature_levels,
aux_loss=True,
with_box_refine=False,
two_stage=False,
num_queries_one2one=300,
num_queries_one2many=0,
mixed_selection=False,
):
""" Initializes the model.
Parameters:
backbone: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
num_classes: number of object classes
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
with_box_refine: iterative bounding box refinement
two_stage: two-stage Deformable DETR
num_queries_one2one: number of object queries for one-to-one matching part
num_queries_one2many: number of object queries for one-to-many matching part
mixed_selection: a trick for two-stage Deformable DETR
"""
super().__init__()
num_queries = num_queries_one2one + num_queries_one2many
self.num_queries = num_queries
self.transformer = transformer
hidden_dim = transformer.d_model
self.class_embed = nn.Linear(hidden_dim, num_classes)
self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
self.num_feature_levels = num_feature_levels
if not two_stage:
self.query_embed = nn.Embedding(num_queries, hidden_dim * 2)
elif mixed_selection:
self.query_embed = nn.Embedding(num_queries, hidden_dim)
if num_feature_levels > 1:
num_backbone_outs = len(backbone.strides)
input_proj_list = []
for _ in range(num_backbone_outs):
in_channels = backbone.num_channels[_]
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim),
)
)
for _ in range(num_feature_levels - num_backbone_outs):
input_proj_list.append(
nn.Sequential(
nn.Conv2d(
in_channels, hidden_dim, kernel_size=3, stride=2, padding=1
),
nn.GroupNorm(32, hidden_dim),
)
)
in_channels = hidden_dim
self.input_proj = nn.ModuleList(input_proj_list)
else:
self.input_proj = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim),
)
]
)
self.backbone = backbone
self.aux_loss = aux_loss
self.with_box_refine = with_box_refine
self.two_stage = two_stage
prior_prob = 0.01
bias_value = -math.log((1 - prior_prob) / prior_prob)
self.class_embed.bias.data = torch.ones(num_classes) * bias_value
nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
for proj in self.input_proj:
nn.init.xavier_uniform_(proj[0].weight, gain=1)
nn.init.constant_(proj[0].bias, 0)
# if two-stage, the last class_embed and bbox_embed is for region proposal generation
num_pred = (
(transformer.decoder.num_layers + 1)
if two_stage
else transformer.decoder.num_layers
)
if with_box_refine:
self.class_embed = _get_clones(self.class_embed, num_pred)
self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
# hack implementation for iterative bounding box refinement
self.transformer.decoder.bbox_embed = self.bbox_embed
else:
nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
self.class_embed = nn.ModuleList(
[self.class_embed for _ in range(num_pred)]
)
self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
self.transformer.decoder.bbox_embed = None
if two_stage:
# hack implementation for two-stage
self.transformer.decoder.class_embed = self.class_embed
for box_embed in self.bbox_embed:
nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)
self.num_queries_one2one = num_queries_one2one
self.mixed_selection = mixed_selection
def forward(self, samples: NestedTensor):
""" The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
It returns a dict with the following elements:
- "pred_logits": the classification logits (including no-object) for all queries.
Shape= [batch_size x num_queries x (num_classes + 1)]
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
(center_x, center_y, width, height). These values are normalized in [0, 1],
relative to the size of each individual image (disregarding possible padding).
See PostProcess for information on how to retrieve the unnormalized bounding box.
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
dictionnaries containing the two above keys for each decoder layer.
"""
if not isinstance(samples, NestedTensor):
samples = nested_tensor_from_tensor_list(samples)
features, pos = self.backbone(samples)
srcs = []
masks = []
for l, feat in enumerate(features):
src, mask = feat.decompose()
srcs.append(self.input_proj[l](src))
masks.append(mask)
assert mask is not None
if self.num_feature_levels > len(srcs):
_len_srcs = len(srcs)
for l in range(_len_srcs, self.num_feature_levels):
if l == _len_srcs:
src = self.input_proj[l](features[-1].tensors)
else:
src = self.input_proj[l](srcs[-1])
m = samples.mask
mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(
torch.bool
)[0]
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
srcs.append(src)
masks.append(mask)
pos.append(pos_l)
query_embeds = None
if not self.two_stage or self.mixed_selection:
query_embeds = self.query_embed.weight[0 : self.num_queries, :]
# make the self-attention mask that prevents information leakage between
# the one-to-one and one-to-many query groups
self_attn_mask = (
torch.zeros([self.num_queries, self.num_queries,]).bool().to(src.device)
)
self_attn_mask[self.num_queries_one2one :, 0 : self.num_queries_one2one,] = True
self_attn_mask[0 : self.num_queries_one2one, self.num_queries_one2one :,] = True
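        # Illustrative block structure (n1 = num_queries_one2one,
        # n2 = num_queries_one2many); True entries are blocked, so neither
        # branch can attend to the other:
        #   [ n1 x n1: False | n1 x n2: True  ]
        #   [ n2 x n1: True  | n2 x n2: False ]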
(
hs,
init_reference,
inter_references,
enc_outputs_class,
enc_outputs_coord_unact,
) = self.transformer(srcs, masks, pos, query_embeds, self_attn_mask)
outputs_classes_one2one = []
outputs_coords_one2one = []
outputs_classes_one2many = []
outputs_coords_one2many = []
for lvl in range(hs.shape[0]):
if lvl == 0:
reference = init_reference
else:
reference = inter_references[lvl - 1]
reference = inverse_sigmoid(reference)
outputs_class = self.class_embed[lvl](hs[lvl])
tmp = self.bbox_embed[lvl](hs[lvl])
if reference.shape[-1] == 4:
tmp += reference
else:
assert reference.shape[-1] == 2
tmp[..., :2] += reference
outputs_coord = tmp.sigmoid()
outputs_classes_one2one.append(
outputs_class[:, 0 : self.num_queries_one2one]
)
outputs_classes_one2many.append(
outputs_class[:, self.num_queries_one2one :]
)
outputs_coords_one2one.append(
outputs_coord[:, 0 : self.num_queries_one2one]
)
outputs_coords_one2many.append(outputs_coord[:, self.num_queries_one2one :])
outputs_classes_one2one = torch.stack(outputs_classes_one2one)
outputs_coords_one2one = torch.stack(outputs_coords_one2one)
outputs_classes_one2many = torch.stack(outputs_classes_one2many)
outputs_coords_one2many = torch.stack(outputs_coords_one2many)
out = {
"pred_logits": outputs_classes_one2one[-1],
"pred_boxes": outputs_coords_one2one[-1],
"pred_logits_one2many": outputs_classes_one2many[-1],
"pred_boxes_one2many": outputs_coords_one2many[-1],
}
if self.aux_loss:
out["aux_outputs"] = self._set_aux_loss(
outputs_classes_one2one, outputs_coords_one2one
)
out["aux_outputs_one2many"] = self._set_aux_loss(
outputs_classes_one2many, outputs_coords_one2many
)
if self.two_stage:
enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
out["enc_outputs"] = {
"pred_logits": enc_outputs_class,
"pred_boxes": enc_outputs_coord,
}
return out
@torch.jit.unused
def _set_aux_loss(self, outputs_class, outputs_coord):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return [
{"pred_logits": a, "pred_boxes": b}
for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
]
class SetCriterion(nn.Module):
""" This class computes the loss for DETR.
The process happens in two steps:
1) we compute hungarian assignment between ground truth boxes and the outputs of the model
2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
"""
def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25):
""" Create the criterion.
Parameters:
num_classes: number of object categories, omitting the special no-object category
matcher: module able to compute a matching between targets and proposals
weight_dict: dict containing as key the names of the losses and as values their relative weight.
losses: list of all the losses to be applied. See get_loss for list of available losses.
focal_alpha: alpha in Focal Loss
"""
super().__init__()
self.num_classes = num_classes
self.matcher = matcher
self.weight_dict = weight_dict
self.losses = losses
self.focal_alpha = focal_alpha
def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
"""Classification loss (NLL)
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
"""
assert "pred_logits" in outputs
src_logits = outputs["pred_logits"]
idx = self._get_src_permutation_idx(indices)
target_classes_o = torch.cat(
[t["labels"][J] for t, (_, J) in zip(targets, indices)]
)
target_classes = torch.full(
src_logits.shape[:2],
self.num_classes,
dtype=torch.int64,
device=src_logits.device,
)
target_classes[idx] = target_classes_o
target_classes_onehot = torch.zeros(
[src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
dtype=src_logits.dtype,
layout=src_logits.layout,
device=src_logits.device,
)
target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
target_classes_onehot = target_classes_onehot[:, :, :-1]
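        # Note: the scatter writes the "1" for background queries (label ==
        # num_classes) into an extra last column, which the slice above then
        # drops, leaving those rows all-zero: the focal-loss target for
        # "no object".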
loss_ce = (
sigmoid_focal_loss(
src_logits,
target_classes_onehot,
num_boxes,
alpha=self.focal_alpha,
gamma=2,
)
* src_logits.shape[1]
)
losses = {"loss_ce": loss_ce}
if log:
# TODO this should probably be a separate loss, not hacked in this one here
losses["class_error"] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
return losses
@torch.no_grad()
def loss_cardinality(self, outputs, targets, indices, num_boxes):
""" Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
"""
pred_logits = outputs["pred_logits"]
device = pred_logits.device
tgt_lengths = torch.as_tensor(
[len(v["labels"]) for v in targets], device=device
)
# Count the number of predictions that are NOT "no-object" (which is the last class)
card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
losses = {"cardinality_error": card_err}
return losses
def loss_boxes(self, outputs, targets, indices, num_boxes):
"""Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
"""
assert "pred_boxes" in outputs
idx = self._get_src_permutation_idx(indices)
src_boxes = outputs["pred_boxes"][idx]
target_boxes = torch.cat(
[t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0
)
loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
losses = {}
losses["loss_bbox"] = loss_bbox.sum() / num_boxes
loss_giou = 1 - torch.diag(
box_ops.generalized_box_iou(
box_ops.box_cxcywh_to_xyxy(src_boxes),
box_ops.box_cxcywh_to_xyxy(target_boxes),
)
)
losses["loss_giou"] = loss_giou.sum() / num_boxes
return losses
def loss_masks(self, outputs, targets, indices, num_boxes):
"""Compute the losses related to the masks: the focal loss and the dice loss.
targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
"""
assert "pred_masks" in outputs
src_idx = self._get_src_permutation_idx(indices)
tgt_idx = self._get_tgt_permutation_idx(indices)
src_masks = outputs["pred_masks"]
# TODO use valid to mask invalid areas due to padding in loss
target_masks, valid = nested_tensor_from_tensor_list(
[t["masks"] for t in targets]
).decompose()
target_masks = target_masks.to(src_masks)
src_masks = src_masks[src_idx]
# upsample predictions to the target size
src_masks = interpolate(
src_masks[:, None],
size=target_masks.shape[-2:],
mode="bilinear",
align_corners=False,
)
src_masks = src_masks[:, 0].flatten(1)
target_masks = target_masks[tgt_idx].flatten(1)
losses = {
"loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
"loss_dice": dice_loss(src_masks, target_masks, num_boxes),
}
return losses
def _get_src_permutation_idx(self, indices):
# permute predictions following indices
batch_idx = torch.cat(
[torch.full_like(src, i) for i, (src, _) in enumerate(indices)]
)
src_idx = torch.cat([src for (src, _) in indices])
return batch_idx, src_idx
def _get_tgt_permutation_idx(self, indices):
# permute targets following indices
batch_idx = torch.cat(
[torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]
)
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
return batch_idx, tgt_idx
def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
loss_map = {
"labels": self.loss_labels,
"cardinality": self.loss_cardinality,
"boxes": self.loss_boxes,
"masks": self.loss_masks,
}
assert loss in loss_map, f"do you really want to compute {loss} loss?"
return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
def forward(self, outputs, targets):
""" This performs the loss computation.
Parameters:
outputs: dict of tensors, see the output specification of the model for the format
targets: list of dicts, such that len(targets) == batch_size.
The expected keys in each dict depends on the losses applied, see each loss' doc
"""
outputs_without_aux = {
k: v
for k, v in outputs.items()
if k != "aux_outputs" and k != "enc_outputs"
}
# Retrieve the matching between the outputs of the last layer and the targets
indices = self.matcher(outputs_without_aux, targets)
# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["labels"]) for t in targets)
num_boxes = torch.as_tensor(
[num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device
)
if is_dist_avail_and_initialized():
torch.distributed.all_reduce(num_boxes)
num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
# Compute all the requested losses
losses = {}
for loss in self.losses:
kwargs = {}
losses.update(
self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs)
)
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
if "aux_outputs" in outputs:
for i, aux_outputs in enumerate(outputs["aux_outputs"]):
indices = self.matcher(aux_outputs, targets)
for loss in self.losses:
if loss == "masks":
# Intermediate mask losses are too costly to compute, so we skip them.
continue
kwargs = {}
if loss == "labels":
# Logging is enabled only for the last layer
kwargs["log"] = False
l_dict = self.get_loss(
loss, aux_outputs, targets, indices, num_boxes, **kwargs
)
l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
losses.update(l_dict)
if "enc_outputs" in outputs:
enc_outputs = outputs["enc_outputs"]
bin_targets = copy.deepcopy(targets)
for bt in bin_targets:
bt["labels"] = torch.zeros_like(bt["labels"])
indices = self.matcher(enc_outputs, bin_targets)
for loss in self.losses:
if loss == "masks":
# Intermediate mask losses are too costly to compute, so we skip them.
continue
kwargs = {}
if loss == "labels":
# Logging is enabled only for the last layer
kwargs["log"] = False
l_dict = self.get_loss(
loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs
)
l_dict = {k + f"_enc": v for k, v in l_dict.items()}
losses.update(l_dict)
return losses
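# The two sketches below are illustrative only and not part of the codebase:
# a minimal reference implementation of the sigmoid focal loss this criterion
# relies on (the real one lives elsewhere in this repo), and a toy usage of
# SetCriterion on random outputs. They assume this module's usual imports
# (torch, torch.nn.functional as F) and that HungarianMatcher is importable;
# adjust the import path to your layout.
def _reference_sigmoid_focal_loss(inputs, targets, num_boxes, alpha=0.25, gamma=2):
    # Focal loss on per-class sigmoid scores: down-weights easy examples by
    # (1 - p_t)^gamma and balances positives/negatives with alpha.
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    return loss.mean(1).sum() / num_boxes
def _demo_set_criterion():
    # Toy batch of 2 images with 1 and 2 ground-truth boxes; the loss weights
    # mirror the defaults assembled in build() below (cls 2.0, bbox 5.0, giou 2.0).
    from models.matcher import HungarianMatcher  # hypothetical import path
    matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
    weight_dict = {"loss_ce": 2.0, "loss_bbox": 5.0, "loss_giou": 2.0}
    criterion = SetCriterion(
        num_classes=91,
        matcher=matcher,
        weight_dict=weight_dict,
        losses=["labels", "boxes", "cardinality"],
    )
    outputs = {
        "pred_logits": torch.randn(2, 300, 91),
        "pred_boxes": torch.rand(2, 300, 4),
    }
    targets = [
        {"labels": torch.tensor([3]), "boxes": torch.rand(1, 4)},
        {"labels": torch.tensor([17, 42]), "boxes": torch.rand(2, 4)},
    ]
    # Returns the unweighted loss dict; weight_dict is applied by the training loop.
    return criterion(outputs, targets)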
class PostProcess(nn.Module):
""" This module converts the model's output into the format expected by the coco api"""
def __init__(self, topk=100):
super().__init__()
self.topk = topk
print("topk for eval:", self.topk)
@torch.no_grad()
def forward(self, outputs, target_sizes):
""" Perform the computation
Parameters:
outputs: raw outputs of the model
target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
For evaluation, this must be the original image size (before any data augmentation)
For visualization, this should be the image size after data augment, but before padding
"""
out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]
assert len(out_logits) == len(target_sizes)
assert target_sizes.shape[1] == 2
prob = out_logits.sigmoid()
topk_values, topk_indexes = torch.topk(
prob.view(out_logits.shape[0], -1), self.topk, dim=1
)
scores = topk_values
topk_boxes = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
# and from relative [0, 1] to absolute [0, height] coordinates
img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
boxes = boxes * scale_fct[:, None, :]
results = [
{"scores": s, "labels": l, "boxes": b}
for s, l, b in zip(scores, labels, boxes)
]
return results
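# Hedged usage sketch for PostProcess: two images whose original (pre-augmentation)
# sizes are 800x1200, fed as (height, width) rows in target_sizes. Shapes are
# illustrative; the demo name is not part of the codebase.
def _demo_postprocess():
    postprocess = PostProcess(topk=100)
    outputs = {
        "pred_logits": torch.randn(2, 300, 91),
        "pred_boxes": torch.rand(2, 300, 4),
    }
    target_sizes = torch.tensor([[800, 1200], [800, 1200]])
    results = postprocess(outputs, target_sizes)
    # each entry: {"scores": (100,), "labels": (100,), "boxes": (100, 4)}
    # with boxes in absolute (x0, y0, x1, y1) pixel coordinates
    return results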
class MLP(nn.Module):
""" Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
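# Usage sketch: DeformableDETR builds its box head exactly like this, a 3-layer
# MLP mapping 256-d query features to 4 normalized (cx, cy, w, h) parameters.
# The demo function is illustrative only.
def _demo_mlp():
    head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
    return head(torch.randn(2, 300, 256))  # -> (2, 300, 4)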
def build(args):
num_classes = 20 if args.dataset_file != "coco" else 91
if args.dataset_file == "coco_panoptic":
num_classes = 250
device = torch.device(args.device)
backbone = build_backbone(args)
transformer = build_deforamble_transformer(args)
model = DeformableDETR(
backbone,
transformer,
num_classes=num_classes,
num_feature_levels=args.num_feature_levels,
aux_loss=args.aux_loss,
with_box_refine=args.with_box_refine,
two_stage=args.two_stage,
num_queries_one2one=args.num_queries_one2one,
num_queries_one2many=args.num_queries_one2many,
mixed_selection=args.mixed_selection,
)
if args.masks:
model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))
matcher = build_matcher(args)
weight_dict = {"loss_ce": args.cls_loss_coef, "loss_bbox": args.bbox_loss_coef}
weight_dict["loss_giou"] = args.giou_loss_coef
if args.masks:
weight_dict["loss_mask"] = args.mask_loss_coef
weight_dict["loss_dice"] = args.dice_loss_coef
# TODO this is a hack
if args.aux_loss:
aux_weight_dict = {}
for i in range(args.dec_layers - 1):
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
aux_weight_dict.update({k + f"_enc": v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
new_dict = dict()
for key, value in weight_dict.items():
new_dict[key] = value
new_dict[key + "_one2many"] = value
weight_dict = new_dict
losses = ["labels", "boxes", "cardinality"]
if args.masks:
losses += ["masks"]
# num_classes, matcher, weight_dict, losses, focal_alpha=0.25
criterion = SetCriterion(
num_classes, matcher, weight_dict, losses, focal_alpha=args.focal_alpha
)
criterion.to(device)
postprocessors = {"bbox": PostProcess(topk=args.topk)}
if args.masks:
postprocessors["segm"] = PostProcessSegm()
if args.dataset_file == "coco_panoptic":
is_thing_map = {i: i <= 90 for i in range(201)}
postprocessors["panoptic"] = PostProcessPanoptic(
is_thing_map, threshold=0.85
)
return model, criterion, postprocessors
def build_test(args):
num_classes = 20 if args.dataset_file != "coco" else 91
backbone = build_backbone(args)
transformer = build_deforamble_transformer(args)
model = DeformableDETR(
backbone,
transformer,
num_classes=num_classes,
num_feature_levels=args.num_feature_levels,
aux_loss=True,
with_box_refine=True,
two_stage=True,
num_queries_one2one=args.num_queries_one2one,
num_queries_one2many=args.num_queries_one2many,
mixed_selection=args.mixed_selection,
)
postprocessors = {"bbox": PostProcess(topk=args.topk)}
return model, postprocessors
# ------------------------------------------------------------------------
# H-DETR
# Copyright (c) 2022 Peking University & Microsoft Research Asia. All Rights Reserved.
# Licensed under the MIT-style license found in the LICENSE file in the root directory
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import copy
from typing import Optional, List
import math
import torch
import torch.nn.functional as F
from torch import nn, Tensor
import torch.utils.checkpoint as checkpoint
from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
from util.misc import inverse_sigmoid
from models.ops.modules import MSDeformAttn
class DeformableTransformer(nn.Module):
def __init__(
self,
d_model=256,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=1024,
dropout=0.1,
activation="relu",
return_intermediate_dec=False,
num_feature_levels=4,
dec_n_points=4,
enc_n_points=4,
two_stage=False,
two_stage_num_proposals=300,
look_forward_twice=False,
mixed_selection=False,
use_checkpoint=False,
):
super().__init__()
self.d_model = d_model
self.nhead = nhead
self.two_stage = two_stage
self.two_stage_num_proposals = two_stage_num_proposals
encoder_layer = DeformableTransformerEncoderLayer(
d_model,
dim_feedforward,
dropout,
activation,
num_feature_levels,
nhead,
enc_n_points,
)
self.encoder = DeformableTransformerEncoder(
encoder_layer, num_encoder_layers, use_checkpoint
)
decoder_layer = DeformableTransformerDecoderLayer(
d_model,
dim_feedforward,
dropout,
activation,
num_feature_levels,
nhead,
dec_n_points,
)
self.decoder = DeformableTransformerDecoder(
decoder_layer,
num_decoder_layers,
return_intermediate_dec,
look_forward_twice,
use_checkpoint,
)
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
if two_stage:
self.enc_output = nn.Linear(d_model, d_model)
self.enc_output_norm = nn.LayerNorm(d_model)
self.pos_trans = nn.Linear(d_model * 2, d_model * 2)
self.pos_trans_norm = nn.LayerNorm(d_model * 2)
else:
self.reference_points = nn.Linear(d_model, 2)
self.mixed_selection = mixed_selection
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MSDeformAttn):
m._reset_parameters()
if not self.two_stage:
xavier_uniform_(self.reference_points.weight.data, gain=1.0)
constant_(self.reference_points.bias.data, 0.0)
normal_(self.level_embed)
def get_proposal_pos_embed(self, proposals):
num_pos_feats = 128
temperature = 10000
scale = 2 * math.pi
dim_t = torch.arange(
num_pos_feats, dtype=torch.float32, device=proposals.device
)
dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
# N, L, 4
proposals = proposals.sigmoid() * scale
# N, L, 4, 128
pos = proposals[:, :, :, None] / dim_t
# N, L, 4, 64, 2
pos = torch.stack(
(pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4
).flatten(2)
return pos
def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
N_, S_, C_ = memory.shape
base_scale = 4.0
proposals = []
_cur = 0
for lvl, (H_, W_) in enumerate(spatial_shapes):
mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(
N_, H_, W_, 1
)
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
grid_y, grid_x = torch.meshgrid(
torch.linspace(
0, H_ - 1, H_, dtype=torch.float32, device=memory.device
),
torch.linspace(
0, W_ - 1, W_, dtype=torch.float32, device=memory.device
),
)
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(
N_, 1, 1, 2
)
grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
proposals.append(proposal)
_cur += H_ * W_
output_proposals = torch.cat(proposals, 1)
output_proposals_valid = (
(output_proposals > 0.01) & (output_proposals < 0.99)
).all(-1, keepdim=True)
output_proposals = torch.log(output_proposals / (1 - output_proposals))
output_proposals = output_proposals.masked_fill(
memory_padding_mask.unsqueeze(-1), float("inf")
)
output_proposals = output_proposals.masked_fill(
~output_proposals_valid, float("inf")
)
output_memory = memory
output_memory = output_memory.masked_fill(
memory_padding_mask.unsqueeze(-1), float(0)
)
output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
output_memory = self.enc_output_norm(self.enc_output(output_memory))
return output_memory, output_proposals
def get_valid_ratio(self, mask):
_, H, W = mask.shape
valid_H = torch.sum(~mask[:, :, 0], 1)
valid_W = torch.sum(~mask[:, 0, :], 1)
valid_ratio_h = valid_H.float() / H
valid_ratio_w = valid_W.float() / W
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
return valid_ratio
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def forward(self, srcs, masks, pos_embeds, query_embed=None, self_attn_mask=None):
# prepare input for encoder
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
bs, c, h, w = src.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
src = src.flatten(2).transpose(1, 2)
mask = mask.flatten(1)
pos_embed = pos_embed.flatten(2).transpose(1, 2)
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
lvl_pos_embed_flatten.append(lvl_pos_embed)
src_flatten.append(src)
mask_flatten.append(mask)
src_flatten = torch.cat(src_flatten, 1)
mask_flatten = torch.cat(mask_flatten, 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
spatial_shapes = torch.as_tensor(
spatial_shapes, dtype=torch.long, device=src_flatten.device
)
level_start_index = torch.cat(
(spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
)
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
# encoder
memory = self.encoder(
src_flatten,
spatial_shapes,
level_start_index,
valid_ratios,
lvl_pos_embed_flatten,
mask_flatten,
)
# prepare input for decoder
bs, _, c = memory.shape
if self.two_stage:
output_memory, output_proposals = self.gen_encoder_output_proposals(
memory, mask_flatten, spatial_shapes
)
# hack implementation for two-stage Deformable DETR
enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](
output_memory
)
enc_outputs_coord_unact = (
self.decoder.bbox_embed[self.decoder.num_layers](output_memory)
+ output_proposals
)
topk = self.two_stage_num_proposals
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
topk_coords_unact = torch.gather(
enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
)
topk_coords_unact = topk_coords_unact.detach()
reference_points = topk_coords_unact.sigmoid()
init_reference_out = reference_points
pos_trans_out = self.pos_trans_norm(
self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))
)
if not self.mixed_selection:
query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
else:
# query_embed here is the content embed for deformable DETR
tgt = query_embed.unsqueeze(0).expand(bs, -1, -1)
query_embed, _ = torch.split(pos_trans_out, c, dim=2)
else:
query_embed, tgt = torch.split(query_embed, c, dim=1)
query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
reference_points = self.reference_points(query_embed).sigmoid()
init_reference_out = reference_points
# decoder
hs, inter_references = self.decoder(
tgt,
reference_points,
memory,
spatial_shapes,
level_start_index,
valid_ratios,
query_embed,
mask_flatten,
self_attn_mask,
)
inter_references_out = inter_references
if self.two_stage:
return (
hs,
init_reference_out,
inter_references_out,
enc_outputs_class,
enc_outputs_coord_unact,
)
return hs, init_reference_out, inter_references_out, None, None
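# A minimal smoke-test sketch for the single-stage path (two_stage=False) with
# two random feature levels. Shapes and the demo name are illustrative only;
# MSDeformAttn requires the compiled CUDA extension, so this assumes a GPU build.
def _demo_deformable_transformer():
    model = DeformableTransformer(
        d_model=256,
        num_feature_levels=2,
        two_stage=False,
        return_intermediate_dec=True,
    ).cuda()
    srcs = [torch.randn(1, 256, 32, 32).cuda(), torch.randn(1, 256, 16, 16).cuda()]
    masks = [
        torch.zeros(1, 32, 32, dtype=torch.bool).cuda(),
        torch.zeros(1, 16, 16, dtype=torch.bool).cuda(),
    ]
    pos = [torch.randn_like(s) for s in srcs]
    # 300 queries; each row is split into a positional half and a content half
    query_embed = torch.randn(300, 512).cuda()
    hs, init_ref, inter_refs, _, _ = model(srcs, masks, pos, query_embed)
    return hs.shape  # (num_decoder_layers, 1, 300, 256) with return_intermediate_dec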
class DeformableTransformerEncoderLayer(nn.Module):
def __init__(
self,
d_model=256,
d_ffn=1024,
dropout=0.1,
activation="relu",
n_levels=4,
n_heads=8,
n_points=4,
):
super().__init__()
# self attention
self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation)
self.dropout2 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout3 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, src):
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
src = src + self.dropout3(src2)
src = self.norm2(src)
return src
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def forward(
self,
src,
pos,
reference_points,
spatial_shapes,
level_start_index,
padding_mask=None,
):
# self attention
src2 = self.self_attn(
self.with_pos_embed(src, pos),
reference_points,
src,
spatial_shapes,
level_start_index,
padding_mask,
)
src = src + self.dropout1(src2)
src = self.norm1(src)
# ffn
src = self.forward_ffn(src)
return src
class DeformableTransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, use_checkpoint=False):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.use_checkpoint = use_checkpoint
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
reference_points_list = []
for lvl, (H_, W_) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
)
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def forward(
self,
src,
spatial_shapes,
level_start_index,
valid_ratios,
pos=None,
padding_mask=None,
):
output = src
reference_points = self.get_reference_points(
spatial_shapes, valid_ratios, device=src.device
)
for _, layer in enumerate(self.layers):
if self.use_checkpoint:
output = checkpoint.checkpoint(
layer,
output,
pos,
reference_points,
spatial_shapes,
level_start_index,
padding_mask,
)
else:
output = layer(
output,
pos,
reference_points,
spatial_shapes,
level_start_index,
padding_mask,
)
return output
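# Sanity sketch for get_reference_points (illustrative only): one 2x2 level with
# no padding (valid ratio 1) yields the normalized pixel centers
# (0.25, 0.25), (0.75, 0.25), (0.25, 0.75), (0.75, 0.75), i.e. ((x + 0.5)/W, (y + 0.5)/H).
def _demo_reference_points():
    spatial_shapes = torch.as_tensor([(2, 2)], dtype=torch.long)
    valid_ratios = torch.ones(1, 1, 2)  # (batch, n_levels, 2)
    return DeformableTransformerEncoder.get_reference_points(
        spatial_shapes, valid_ratios, device="cpu"
    )  # -> (1, 4, 1, 2)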
class DeformableTransformerDecoderLayer(nn.Module):
def __init__(
self,
d_model=256,
d_ffn=1024,
dropout=0.1,
activation="relu",
n_levels=4,
n_heads=8,
n_points=4,
):
super().__init__()
# cross attention
self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
# self attention
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
self.dropout2 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation)
self.dropout3 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout4 = nn.Dropout(dropout)
self.norm3 = nn.LayerNorm(d_model)
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, tgt):
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout4(tgt2)
tgt = self.norm3(tgt)
return tgt
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def forward(
self,
tgt,
query_pos,
reference_points,
src,
src_spatial_shapes,
level_start_index,
src_padding_mask=None,
self_attn_mask=None,
):
# self attention
q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(
q.transpose(0, 1),
k.transpose(0, 1),
tgt.transpose(0, 1),
attn_mask=self_attn_mask,
)[0].transpose(0, 1)
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
# cross attention
tgt2 = self.cross_attn(
self.with_pos_embed(tgt, query_pos),
reference_points,
src,
src_spatial_shapes,
level_start_index,
src_padding_mask,
)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# ffn
tgt = self.forward_ffn(tgt)
return tgt
class DeformableTransformerDecoder(nn.Module):
def __init__(
self,
decoder_layer,
num_layers,
return_intermediate=False,
look_forward_twice=False,
use_checkpoint=False,
):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.return_intermediate = return_intermediate
self.look_forward_twice = look_forward_twice
self.use_checkpoint = use_checkpoint
# hack implementation for iterative bounding box refinement and two-stage Deformable DETR
self.bbox_embed = None
self.class_embed = None
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def forward(
self,
tgt,
reference_points,
src,
src_spatial_shapes,
src_level_start_index,
src_valid_ratios,
query_pos=None,
src_padding_mask=None,
self_attn_mask=None,
):
output = tgt
intermediate = []
intermediate_reference_points = []
for lid, layer in enumerate(self.layers):
if reference_points.shape[-1] == 4:
reference_points_input = (
reference_points[:, :, None]
* torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]
)
else:
assert reference_points.shape[-1] == 2
reference_points_input = (
reference_points[:, :, None] * src_valid_ratios[:, None]
)
if self.use_checkpoint:
output = checkpoint.checkpoint(
layer,
output,
query_pos,
reference_points_input,
src,
src_spatial_shapes,
src_level_start_index,
src_padding_mask,
self_attn_mask,
)
else:
output = layer(
output,
query_pos,
reference_points_input,
src,
src_spatial_shapes,
src_level_start_index,
src_padding_mask,
self_attn_mask,
)
# hack implementation for iterative bounding box refinement
if self.bbox_embed is not None:
tmp = self.bbox_embed[lid](output)
if reference_points.shape[-1] == 4:
new_reference_points = tmp + inverse_sigmoid(reference_points)
new_reference_points = new_reference_points.sigmoid()
else:
assert reference_points.shape[-1] == 2
new_reference_points = tmp
new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(
reference_points
)
new_reference_points = new_reference_points.sigmoid()
reference_points = new_reference_points.detach()
if self.return_intermediate:
intermediate.append(output)
intermediate_reference_points.append(
new_reference_points
if self.look_forward_twice
else reference_points
)
if self.return_intermediate:
return torch.stack(intermediate), torch.stack(intermediate_reference_points)
return output, reference_points
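# Note on look_forward_twice (the DINO-style refinement trick used by H-DETR):
# reference points are detached between layers, so by default layer i+1's box
# loss cannot reach layer i. Returning new_reference_points (captured before
# the detach) lets the prediction head that consumes inter_references propagate
# layer i+1's supervision back into layer i's bbox_embed as well.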
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
def build_deforamble_transformer(args):
return DeformableTransformer(
d_model=args.hidden_dim,
nhead=args.nheads,
num_encoder_layers=args.enc_layers,
num_decoder_layers=args.dec_layers,
dim_feedforward=args.dim_feedforward,
dropout=args.dropout,
activation="relu",
return_intermediate_dec=True,
num_feature_levels=args.num_feature_levels,
dec_n_points=args.dec_n_points,
enc_n_points=args.enc_n_points,
two_stage=args.two_stage,
two_stage_num_proposals=args.num_queries_one2one + args.num_queries_one2many,
mixed_selection=args.mixed_selection,
look_forward_twice=args.look_forward_twice,
use_checkpoint=args.use_checkpoint,
)
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
"""
Modules to compute the matching cost and solve the corresponding LSAP (linear sum assignment problem).
"""
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn
from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
class HungarianMatcher(nn.Module):
"""This class computes an assignment between the targets and the predictions of the network
For efficiency reasons, the targets don't include the no_object. Because of this, in general,
there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
while the others are unmatched (and thus treated as non-objects).
"""
def __init__(
self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1
):
"""Creates the matcher
Params:
cost_class: This is the relative weight of the classification error in the matching cost
cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
"""
super().__init__()
self.cost_class = cost_class
self.cost_bbox = cost_bbox
self.cost_giou = cost_giou
assert (
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
), "all costs cant be 0"
def forward(self, outputs, targets):
""" Performs the matching
Params:
outputs: This is a dict that contains at least these entries:
"pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
"pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
"labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
objects in the target) containing the class labels
"boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
Returns:
A list of size batch_size, containing tuples of (index_i, index_j) where:
- index_i is the indices of the selected predictions (in order)
- index_j is the indices of the corresponding selected targets (in order)
For each batch element, it holds:
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
"""
with torch.no_grad():
bs, num_queries = outputs["pred_logits"].shape[:2]
# We flatten to compute the cost matrices in a batch
out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
out_bbox = outputs["pred_boxes"].flatten(
0, 1
) # [batch_size * num_queries, 4]
# Also concat the target labels and boxes
tgt_ids = torch.cat([v["labels"] for v in targets])
tgt_bbox = torch.cat([v["boxes"] for v in targets])
# Compute the classification cost.
alpha = 0.25
gamma = 2.0
neg_cost_class = (
(1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
)
pos_cost_class = (
alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
)
cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
# Compute the L1 cost between boxes
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
# Compute the GIoU cost between boxes
cost_giou = -generalized_box_iou(
box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
)
# Final cost matrix
C = (
self.cost_bbox * cost_bbox
+ self.cost_class * cost_class
+ self.cost_giou * cost_giou
)
C = C.view(bs, num_queries, -1).cpu()
sizes = [len(v["boxes"]) for v in targets]
indices = [
linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))
]
return [
(
torch.as_tensor(i, dtype=torch.int64),
torch.as_tensor(j, dtype=torch.int64),
)
for i, j in indices
]
def build_matcher(args):
return HungarianMatcher(
cost_class=args.set_cost_class,
cost_bbox=args.set_cost_bbox,
cost_giou=args.set_cost_giou,
)
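# Toy sketch of the matcher on one image with 5 queries and 2 ground-truth
# boxes; boxes are normalized (cx, cy, w, h), as assumed throughout the repo.
# The demo function is illustrative only.
def _demo_matcher():
    matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
    outputs = {
        "pred_logits": torch.randn(1, 5, 91),
        "pred_boxes": torch.rand(1, 5, 4),
    }
    targets = [{"labels": torch.tensor([3, 7]), "boxes": torch.rand(2, 4)}]
    indices = matcher(outputs, targets)
    # -> [(pred_idx, tgt_idx)] with len(pred_idx) == len(tgt_idx) == 2
    return indices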
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from .ms_deform_attn_func import MSDeformAttnFunction
# ------------------------------------------------------------------------
# H-DETR
# Copyright (c) 2022 Peking University & Microsoft Research Asia. All Rights Reserved.
# Licensed under the MIT-style license found in the LICENSE file in the root directory
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
import MultiScaleDeformableAttention as MSDA
class MSDeformAttnFunction(Function):
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def forward(
ctx,
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
im2col_step,
):
ctx.im2col_step = im2col_step
output = MSDA.ms_deform_attn_forward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
ctx.im2col_step,
)
ctx.save_for_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
)
return output
@staticmethod
@once_differentiable
@torch.cuda.amp.custom_bwd
def backward(ctx, grad_output):
(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
) = ctx.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output,
ctx.im2col_step,
)
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
def ms_deform_attn_core_pytorch(
value, value_spatial_shapes, sampling_locations, attention_weights
):
# for debugging and testing only;
# use the CUDA version in practice
N_, S_, M_, D_ = value.shape
_, Lq_, M_, L_, P_, _ = sampling_locations.shape
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for lid_, (H_, W_) in enumerate(value_spatial_shapes):
# N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
value_l_ = (
value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_)
)
# N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
# N_*M_, D_, Lq_, P_
sampling_value_l_ = F.grid_sample(
value_l_,
sampling_grid_l_,
mode="bilinear",
padding_mode="zeros",
align_corners=False,
)
sampling_value_list.append(sampling_value_l_)
# (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
attention_weights = attention_weights.transpose(1, 2).reshape(
N_ * M_, 1, Lq_, L_ * P_
)
output = (
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
.sum(-1)
.view(N_, M_ * D_, Lq_)
)
return output.transpose(1, 2).contiguous()
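# A hedged consistency check between the CUDA kernel and the PyTorch reference
# above, in the spirit of the upstream test script; it assumes a GPU build of
# MultiScaleDeformableAttention. Tolerances are loose because of floating-point
# accumulation differences. Illustrative only.
def _check_forward_equal():
    N, M, D, Lq, P = 1, 2, 16, 8, 4
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
    level_start_index = torch.cat(
        (shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1])
    )
    S = int((shapes[:, 0] * shapes[:, 1]).sum())
    value = torch.rand(N, S, M, D).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, shapes.shape[0], P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, shapes.shape[0], P).cuda() + 1e-5
    # normalize so weights sum to 1 over (levels, points)
    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
    out_cuda = MSDeformAttnFunction.apply(
        value, shapes, level_start_index, sampling_locations, attention_weights, 64
    )
    out_ref = ms_deform_attn_core_pytorch(
        value, shapes, sampling_locations, attention_weights
    )
    return torch.allclose(out_cuda, out_ref, rtol=1e-2, atol=1e-3)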
#!/usr/bin/env bash
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
python setup.py build install
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from .ms_deform_attn import MSDeformAttn
# ------------------------------------------------------------------------
# H-DETR
# Copyright (c) 2022 Peking University & Microsoft Research Asia. All Rights Reserved.
# Licensed under the MIT-style license found in the LICENSE file in the root directory
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import warnings
import math
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_
from ..functions import MSDeformAttnFunction
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError(
"invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))
)
return (n & (n - 1) == 0) and n != 0
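# e.g. _is_power_of_2(64) -> True, _is_power_of_2(48) -> False; used below to
# warn when the per-head channel count would hurt the CUDA kernel's efficiency.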
class MSDeformAttn(nn.Module):
def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
"""
Multi-Scale Deformable Attention Module
:param d_model hidden dimension
:param n_levels number of feature levels
:param n_heads number of attention heads
:param n_points number of sampling points per attention head per feature level
"""
super().__init__()
if d_model % n_heads != 0:
raise ValueError(
"d_model must be divisible by n_heads, but got {} and {}".format(
d_model, n_heads
)
)
_d_per_head = d_model // n_heads
# it is recommended to keep _d_per_head a power of 2, which is more efficient in the CUDA implementation
if not _is_power_of_2(_d_per_head):
warnings.warn(
"You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
"which is more efficient in our CUDA implementation."
)
self.im2col_step = 64
self.d_model = d_model
self.n_levels = n_levels
self.n_heads = n_heads
self.n_points = n_points
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
self.value_proj = nn.Linear(d_model, d_model)
self.output_proj = nn.Linear(d_model, d_model)
self._reset_parameters()
def _reset_parameters(self):
constant_(self.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (
2.0 * math.pi / self.n_heads
)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
.view(self.n_heads, 1, 1, 2)
.repeat(1, self.n_levels, self.n_points, 1)
)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.0)
constant_(self.attention_weights.bias.data, 0.0)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.0)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.0)
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def forward(
self,
query,
reference_points,
input_flatten,
input_spatial_shapes,
input_level_start_index,
input_padding_mask=None,
):
"""
:param query (N, Length_{query}, C)
:param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
:param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
:param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
:param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
:param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
:return output (N, Length_{query}, C)
"""
N, Len_q, _ = query.shape
N, Len_in, _ = input_flatten.shape
assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
value = self.value_proj(input_flatten)
if input_padding_mask is not None:
value = value.masked_fill(input_padding_mask[..., None], float(0))
value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(query).view(
N, Len_q, self.n_heads, self.n_levels, self.n_points, 2
)
attention_weights = self.attention_weights(query).view(
N, Len_q, self.n_heads, self.n_levels * self.n_points
)
attention_weights = F.softmax(attention_weights, -1).view(
N, Len_q, self.n_heads, self.n_levels, self.n_points
)
# N, Len_q, n_heads, n_levels, n_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack(
[input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1
)
sampling_locations = (
reference_points[:, :, None, :, None, :]
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
)
elif reference_points.shape[-1] == 4:
sampling_locations = (
reference_points[:, :, None, :, None, :2]
+ sampling_offsets
/ self.n_points
* reference_points[:, :, None, :, None, 2:]
* 0.5
)
else:
raise ValueError(
"Last dim of reference_points must be 2 or 4, but get {} instead.".format(
reference_points.shape[-1]
)
)
output = MSDeformAttnFunction.apply(
value,
input_spatial_shapes,
input_level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
output = self.output_proj(output)
return output
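# Hedged usage sketch (assumes the compiled CUDA extension): 300 queries
# attending over two flattened feature levels with point-style (2-d) reference
# points. Shapes and the demo name are illustrative only.
def _demo_msdeformattn():
    attn = MSDeformAttn(d_model=256, n_levels=2, n_heads=8, n_points=4).cuda()
    shapes = torch.as_tensor([(32, 32), (16, 16)], dtype=torch.long).cuda()
    level_start_index = torch.cat(
        (shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1])
    )
    src = torch.randn(1, int((shapes[:, 0] * shapes[:, 1]).sum()), 256).cuda()
    query = torch.randn(1, 300, 256).cuda()
    reference_points = torch.rand(1, 300, 2, 2).cuda()  # (N, Len_q, n_levels, 2)
    return attn(query, reference_points, src, shapes, level_start_index)  # (1, 300, 256)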
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
import os
import glob
import torch
from torch.utils.cpp_extension import CUDA_HOME
from torch.utils.cpp_extension import CppExtension
from torch.utils.cpp_extension import CUDAExtension
from setuptools import find_packages
from setuptools import setup
requirements = ["torch", "torchvision"]
def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "src")
main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
sources = main_file + source_cpu
extension = CppExtension
extra_compile_args = {"cxx": []}
define_macros = []
if torch.cuda.is_available():
extension = CUDAExtension
sources += source_cuda
define_macros += [("WITH_CUDA", None)]
extra_compile_args["nvcc"] = [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
]
else:
raise NotImplementedError("CUDA is not available")
sources = [os.path.join(extensions_dir, s) for s in sources]
include_dirs = [extensions_dir]
ext_modules = [
extension(
"MultiScaleDeformableAttention",
sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]
return ext_modules
setup(
name="MultiScaleDeformableAttention",
version="1.0",
author="Weijie Su",
url="https://github.com/fundamentalvision/Deformable-DETR",
description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
packages=find_packages(exclude=("configs", "tests",)),
ext_modules=get_extensions(),
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <vector>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
at::Tensor
ms_deform_attn_cpu_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
AT_ERROR("Not implement on cpu");
}
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
AT_ERROR("Not implement on cpu");
}
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
at::Tensor
ms_deform_attn_cpu_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);