Commit e0a11e60 authored by luopl's avatar luopl
Browse files

init

parents
# ========================================
# Modified by Shoufa Chen
# ========================================
# Modified by Peize Sun, Rufeng Zhang
# Contact: {sunpeize, cxrfzhang}@foxmail.com
#
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import logging
import numpy as np
import torch
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
__all__ = ["DiffusionDetDatasetMapper"]
def build_transform_gen(cfg, is_train):
"""
Create a list of :class:`TransformGen` from config.
Returns:
list[TransformGen]
"""
if is_train:
min_size = cfg.INPUT.MIN_SIZE_TRAIN
max_size = cfg.INPUT.MAX_SIZE_TRAIN
sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
else:
min_size = cfg.INPUT.MIN_SIZE_TEST
max_size = cfg.INPUT.MAX_SIZE_TEST
sample_style = "choice"
if sample_style == "range":
assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size))
logger = logging.getLogger(__name__)
tfm_gens = []
if is_train:
tfm_gens.append(T.RandomFlip())
# ResizeShortestEdge
tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
if is_train:
logger.info("TransformGens used in training: " + str(tfm_gens))
return tfm_gens
class DiffusionDetDatasetMapper:
"""
A callable which takes a dataset dict in Detectron2 Dataset format,
and map it into a format used by DiffusionDet.
The callable currently does the following:
1. Read the image from "file_name"
2. Applies geometric transforms to the image and annotation
3. Find and applies suitable cropping to the image and annotation
4. Prepare image and annotation to Tensors
"""
def __init__(self, cfg, is_train=True):
if cfg.INPUT.CROP.ENABLED and is_train:
self.crop_gen = [
T.ResizeShortestEdge([400, 500, 600], sample_style="choice"),
T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE),
]
else:
self.crop_gen = None
self.tfm_gens = build_transform_gen(cfg, is_train)
logging.getLogger(__name__).info(
"Full TransformGens used in training: {}, crop: {}".format(str(self.tfm_gens), str(self.crop_gen))
)
self.img_format = cfg.INPUT.FORMAT
self.is_train = is_train
def __call__(self, dataset_dict):
"""
Args:
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
Returns:
dict: a format that builtin models in detectron2 accept
"""
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
utils.check_image_size(dataset_dict, image)
if self.crop_gen is None:
image, transforms = T.apply_transform_gens(self.tfm_gens, image)
else:
if np.random.rand() > 0.5:
image, transforms = T.apply_transform_gens(self.tfm_gens, image)
else:
image, transforms = T.apply_transform_gens(
self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image
)
image_shape = image.shape[:2] # h, w
# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
# but not efficient on large generic data structures due to the use of pickle & mp.Queue.
# Therefore it's important to use torch.Tensor.
dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
if not self.is_train:
# USER: Modify this if you want to keep them for some reason.
dataset_dict.pop("annotations", None)
return dataset_dict
if "annotations" in dataset_dict:
# USER: Modify this if you want to keep them for some reason.
for anno in dataset_dict["annotations"]:
anno.pop("segmentation", None)
anno.pop("keypoints", None)
# USER: Implement additional transformations if you have other types of data
annos = [
utils.transform_instance_annotations(obj, transforms, image_shape)
for obj in dataset_dict.pop("annotations")
if obj.get("iscrowd", 0) == 0
]
instances = utils.annotations_to_instances(annos, image_shape)
dataset_dict["instances"] = utils.filter_empty_instances(instances)
return dataset_dict
# ========================================
# Modified by Shoufa Chen
# ========================================
# Modified by Peize Sun, Rufeng Zhang
# Contact: {sunpeize, cxrfzhang}@foxmail.com
#
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import math
import random
from typing import List
from collections import namedtuple
import torch
import torch.nn.functional as F
from torch import nn
from detectron2.layers import batched_nms
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, detector_postprocess
from detectron2.structures import Boxes, ImageList, Instances
from .loss import SetCriterionDynamicK, HungarianMatcherDynamicK
from .head import DynamicHead
from .util.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
from .util.misc import nested_tensor_from_tensor_list
__all__ = ["DiffusionDet"]
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
def extract(a, t, x_shape):
"""extract the appropriate t index for a batch of indices"""
batch_size = t.shape[0]
out = a.gather(-1, t)
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
def cosine_beta_schedule(timesteps, s=0.008):
"""
cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
@META_ARCH_REGISTRY.register()
class DiffusionDet(nn.Module):
"""
Implement DiffusionDet
"""
def __init__(self, cfg):
super().__init__()
self.device = torch.device(cfg.MODEL.DEVICE)
self.in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES
self.num_classes = cfg.MODEL.DiffusionDet.NUM_CLASSES
self.num_proposals = cfg.MODEL.DiffusionDet.NUM_PROPOSALS
self.hidden_dim = cfg.MODEL.DiffusionDet.HIDDEN_DIM
self.num_heads = cfg.MODEL.DiffusionDet.NUM_HEADS
# Build Backbone.
self.backbone = build_backbone(cfg)
self.size_divisibility = self.backbone.size_divisibility
# build diffusion
timesteps = 1000
sampling_timesteps = cfg.MODEL.DiffusionDet.SAMPLE_STEP
self.objective = 'pred_x0'
betas = cosine_beta_schedule(timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.sampling_timesteps = default(sampling_timesteps, timesteps)
assert self.sampling_timesteps <= timesteps
self.is_ddim_sampling = self.sampling_timesteps < timesteps
self.ddim_sampling_eta = 1.
self.self_condition = False
self.scale = cfg.MODEL.DiffusionDet.SNR_SCALE
self.box_renewal = True
self.use_ensemble = True
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2',
(1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
# Build Dynamic Head.
self.head = DynamicHead(cfg=cfg, roi_input_shape=self.backbone.output_shape())
# Loss parameters:
class_weight = cfg.MODEL.DiffusionDet.CLASS_WEIGHT
giou_weight = cfg.MODEL.DiffusionDet.GIOU_WEIGHT
l1_weight = cfg.MODEL.DiffusionDet.L1_WEIGHT
no_object_weight = cfg.MODEL.DiffusionDet.NO_OBJECT_WEIGHT
self.deep_supervision = cfg.MODEL.DiffusionDet.DEEP_SUPERVISION
self.use_focal = cfg.MODEL.DiffusionDet.USE_FOCAL
self.use_fed_loss = cfg.MODEL.DiffusionDet.USE_FED_LOSS
self.use_nms = cfg.MODEL.DiffusionDet.USE_NMS
# Build Criterion.
matcher = HungarianMatcherDynamicK(
cfg=cfg, cost_class=class_weight, cost_bbox=l1_weight, cost_giou=giou_weight, use_focal=self.use_focal
)
weight_dict = {"loss_ce": class_weight, "loss_bbox": l1_weight, "loss_giou": giou_weight}
if self.deep_supervision:
aux_weight_dict = {}
for i in range(self.num_heads - 1):
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
losses = ["labels", "boxes"]
self.criterion = SetCriterionDynamicK(
cfg=cfg, num_classes=self.num_classes, matcher=matcher, weight_dict=weight_dict, eos_coef=no_object_weight,
losses=losses, use_focal=self.use_focal,)
pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1)
pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1)
self.normalizer = lambda x: (x - pixel_mean) / pixel_std
self.to(self.device)
def predict_noise_from_start(self, x_t, t, x0):
return (
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) /
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
)
def model_predictions(self, backbone_feats, images_whwh, x, t, x_self_cond=None, clip_x_start=False):
x_boxes = torch.clamp(x, min=-1 * self.scale, max=self.scale)
x_boxes = ((x_boxes / self.scale) + 1) / 2
x_boxes = box_cxcywh_to_xyxy(x_boxes)
x_boxes = x_boxes * images_whwh[:, None, :]
outputs_class, outputs_coord = self.head(backbone_feats, x_boxes, t, None)
x_start = outputs_coord[-1] # (batch, num_proposals, 4) predict boxes: absolute coordinates (x1, y1, x2, y2)
x_start = x_start / images_whwh[:, None, :]
x_start = box_xyxy_to_cxcywh(x_start)
x_start = (x_start * 2 - 1.) * self.scale
x_start = torch.clamp(x_start, min=-1 * self.scale, max=self.scale)
pred_noise = self.predict_noise_from_start(x, t, x_start)
return ModelPrediction(pred_noise, x_start), outputs_class, outputs_coord
@torch.no_grad()
def ddim_sample(self, batched_inputs, backbone_feats, images_whwh, images, clip_denoised=True, do_postprocess=True):
batch = images_whwh.shape[0]
shape = (batch, self.num_proposals, 4)
total_timesteps, sampling_timesteps, eta, objective = self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
# [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
img = torch.randn(shape, device=self.device)
ensemble_score, ensemble_label, ensemble_coord = [], [], []
x_start = None
for time, time_next in time_pairs:
time_cond = torch.full((batch,), time, device=self.device, dtype=torch.long)
self_cond = x_start if self.self_condition else None
preds, outputs_class, outputs_coord = self.model_predictions(backbone_feats, images_whwh, img, time_cond,
self_cond, clip_x_start=clip_denoised)
pred_noise, x_start = preds.pred_noise, preds.pred_x_start
if self.box_renewal: # filter
score_per_image, box_per_image = outputs_class[-1][0], outputs_coord[-1][0]
threshold = 0.5
score_per_image = torch.sigmoid(score_per_image)
value, _ = torch.max(score_per_image, -1, keepdim=False)
keep_idx = value > threshold
num_remain = torch.sum(keep_idx)
pred_noise = pred_noise[:, keep_idx, :]
x_start = x_start[:, keep_idx, :]
img = img[:, keep_idx, :]
if time_next < 0:
img = x_start
continue
alpha = self.alphas_cumprod[time]
alpha_next = self.alphas_cumprod[time_next]
sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c = (1 - alpha_next - sigma ** 2).sqrt()
noise = torch.randn_like(img)
img = x_start * alpha_next.sqrt() + \
c * pred_noise + \
sigma * noise
if self.box_renewal: # filter
# replenish with randn boxes
img = torch.cat((img, torch.randn(1, self.num_proposals - num_remain, 4, device=img.device)), dim=1)
if self.use_ensemble and self.sampling_timesteps > 1:
box_pred_per_image, scores_per_image, labels_per_image = self.inference(outputs_class[-1],
outputs_coord[-1],
images.image_sizes)
ensemble_score.append(scores_per_image)
ensemble_label.append(labels_per_image)
ensemble_coord.append(box_pred_per_image)
if self.use_ensemble and self.sampling_timesteps > 1:
box_pred_per_image = torch.cat(ensemble_coord, dim=0)
scores_per_image = torch.cat(ensemble_score, dim=0)
labels_per_image = torch.cat(ensemble_label, dim=0)
if self.use_nms:
keep = batched_nms(box_pred_per_image, scores_per_image, labels_per_image, 0.5)
box_pred_per_image = box_pred_per_image[keep]
scores_per_image = scores_per_image[keep]
labels_per_image = labels_per_image[keep]
result = Instances(images.image_sizes[0])
result.pred_boxes = Boxes(box_pred_per_image)
result.scores = scores_per_image
result.pred_classes = labels_per_image
results = [result]
else:
output = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
box_cls = output["pred_logits"]
box_pred = output["pred_boxes"]
results = self.inference(box_cls, box_pred, images.image_sizes)
if do_postprocess:
processed_results = []
for results_per_image, input_per_image, image_size in zip(results, batched_inputs, images.image_sizes):
height = input_per_image.get("height", image_size[0])
width = input_per_image.get("width", image_size[1])
r = detector_postprocess(results_per_image, height, width)
processed_results.append({"instances": r})
return processed_results
# forward diffusion
def q_sample(self, x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
def forward(self, batched_inputs, do_postprocess=True):
"""
Args:
batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
Each item in the list contains the inputs for one image.
For now, each item in the list is a dict that contains:
* image: Tensor, image in (C, H, W) format.
* instances: Instances
Other information that's included in the original dicts, such as:
* "height", "width" (int): the output resolution of the model, used in inference.
See :meth:`postprocess` for details.
"""
images, images_whwh = self.preprocess_image(batched_inputs)
if isinstance(images, (list, torch.Tensor)):
images = nested_tensor_from_tensor_list(images)
# Feature Extraction.
src = self.backbone(images.tensor)
features = list()
for f in self.in_features:
feature = src[f]
features.append(feature)
# Prepare Proposals.
if not self.training:
results = self.ddim_sample(batched_inputs, features, images_whwh, images)
return results
if self.training:
gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
targets, x_boxes, noises, t = self.prepare_targets(gt_instances)
t = t.squeeze(-1)
x_boxes = x_boxes * images_whwh[:, None, :]
outputs_class, outputs_coord = self.head(features, x_boxes, t, None)
output = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
if self.deep_supervision:
output['aux_outputs'] = [{'pred_logits': a, 'pred_boxes': b}
for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
loss_dict = self.criterion(output, targets)
weight_dict = self.criterion.weight_dict
for k in loss_dict.keys():
if k in weight_dict:
loss_dict[k] *= weight_dict[k]
return loss_dict
def prepare_diffusion_repeat(self, gt_boxes):
"""
:param gt_boxes: (cx, cy, w, h), normalized
:param num_proposals:
"""
t = torch.randint(0, self.num_timesteps, (1,), device=self.device).long()
noise = torch.randn(self.num_proposals, 4, device=self.device)
num_gt = gt_boxes.shape[0]
if not num_gt: # generate fake gt boxes if empty gt boxes
gt_boxes = torch.as_tensor([[0.5, 0.5, 1., 1.]], dtype=torch.float, device=self.device)
num_gt = 1
num_repeat = self.num_proposals // num_gt # number of repeat except the last gt box in one image
repeat_tensor = [num_repeat] * (num_gt - self.num_proposals % num_gt) + [num_repeat + 1] * (
self.num_proposals % num_gt)
assert sum(repeat_tensor) == self.num_proposals
random.shuffle(repeat_tensor)
repeat_tensor = torch.tensor(repeat_tensor, device=self.device)
gt_boxes = (gt_boxes * 2. - 1.) * self.scale
x_start = torch.repeat_interleave(gt_boxes, repeat_tensor, dim=0)
# noise sample
x = self.q_sample(x_start=x_start, t=t, noise=noise)
x = torch.clamp(x, min=-1 * self.scale, max=self.scale)
x = ((x / self.scale) + 1) / 2.
diff_boxes = box_cxcywh_to_xyxy(x)
return diff_boxes, noise, t
def prepare_diffusion_concat(self, gt_boxes):
"""
:param gt_boxes: (cx, cy, w, h), normalized
:param num_proposals:
"""
t = torch.randint(0, self.num_timesteps, (1,), device=self.device).long()
noise = torch.randn(self.num_proposals, 4, device=self.device)
num_gt = gt_boxes.shape[0]
if not num_gt: # generate fake gt boxes if empty gt boxes
gt_boxes = torch.as_tensor([[0.5, 0.5, 1., 1.]], dtype=torch.float, device=self.device)
num_gt = 1
if num_gt < self.num_proposals:
box_placeholder = torch.randn(self.num_proposals - num_gt, 4,
device=self.device) / 6. + 0.5 # 3sigma = 1/2 --> sigma: 1/6
box_placeholder[:, 2:] = torch.clip(box_placeholder[:, 2:], min=1e-4)
x_start = torch.cat((gt_boxes, box_placeholder), dim=0)
elif num_gt > self.num_proposals:
select_mask = [True] * self.num_proposals + [False] * (num_gt - self.num_proposals)
random.shuffle(select_mask)
x_start = gt_boxes[select_mask]
else:
x_start = gt_boxes
x_start = (x_start * 2. - 1.) * self.scale
# noise sample
x = self.q_sample(x_start=x_start, t=t, noise=noise)
x = torch.clamp(x, min=-1 * self.scale, max=self.scale)
x = ((x / self.scale) + 1) / 2.
diff_boxes = box_cxcywh_to_xyxy(x)
return diff_boxes, noise, t
def prepare_targets(self, targets):
new_targets = []
diffused_boxes = []
noises = []
ts = []
for targets_per_image in targets:
target = {}
h, w = targets_per_image.image_size
image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device)
gt_classes = targets_per_image.gt_classes
gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy
gt_boxes = box_xyxy_to_cxcywh(gt_boxes)
d_boxes, d_noise, d_t = self.prepare_diffusion_concat(gt_boxes)
diffused_boxes.append(d_boxes)
noises.append(d_noise)
ts.append(d_t)
target["labels"] = gt_classes.to(self.device)
target["boxes"] = gt_boxes.to(self.device)
target["boxes_xyxy"] = targets_per_image.gt_boxes.tensor.to(self.device)
target["image_size_xyxy"] = image_size_xyxy.to(self.device)
image_size_xyxy_tgt = image_size_xyxy.unsqueeze(0).repeat(len(gt_boxes), 1)
target["image_size_xyxy_tgt"] = image_size_xyxy_tgt.to(self.device)
target["area"] = targets_per_image.gt_boxes.area().to(self.device)
new_targets.append(target)
return new_targets, torch.stack(diffused_boxes), torch.stack(noises), torch.stack(ts)
def inference(self, box_cls, box_pred, image_sizes):
"""
Arguments:
box_cls (Tensor): tensor of shape (batch_size, num_proposals, K).
The tensor predicts the classification probability for each proposal.
box_pred (Tensor): tensors of shape (batch_size, num_proposals, 4).
The tensor predicts 4-vector (x,y,w,h) box
regression values for every proposal
image_sizes (List[torch.Size]): the input image sizes
Returns:
results (List[Instances]): a list of #images elements.
"""
assert len(box_cls) == len(image_sizes)
results = []
if self.use_focal or self.use_fed_loss:
scores = torch.sigmoid(box_cls)
labels = torch.arange(self.num_classes, device=self.device). \
unsqueeze(0).repeat(self.num_proposals, 1).flatten(0, 1)
for i, (scores_per_image, box_pred_per_image, image_size) in enumerate(zip(
scores, box_pred, image_sizes
)):
result = Instances(image_size)
scores_per_image, topk_indices = scores_per_image.flatten(0, 1).topk(self.num_proposals, sorted=False)
labels_per_image = labels[topk_indices]
box_pred_per_image = box_pred_per_image.view(-1, 1, 4).repeat(1, self.num_classes, 1).view(-1, 4)
box_pred_per_image = box_pred_per_image[topk_indices]
if self.use_ensemble and self.sampling_timesteps > 1:
return box_pred_per_image, scores_per_image, labels_per_image
if self.use_nms:
keep = batched_nms(box_pred_per_image, scores_per_image, labels_per_image, 0.5)
box_pred_per_image = box_pred_per_image[keep]
scores_per_image = scores_per_image[keep]
labels_per_image = labels_per_image[keep]
result.pred_boxes = Boxes(box_pred_per_image)
result.scores = scores_per_image
result.pred_classes = labels_per_image
results.append(result)
else:
# For each box we assign the best class or the second best if the best on is `no_object`.
scores, labels = F.softmax(box_cls, dim=-1)[:, :, :-1].max(-1)
for i, (scores_per_image, labels_per_image, box_pred_per_image, image_size) in enumerate(zip(
scores, labels, box_pred, image_sizes
)):
if self.use_ensemble and self.sampling_timesteps > 1:
return box_pred_per_image, scores_per_image, labels_per_image
if self.use_nms:
keep = batched_nms(box_pred_per_image, scores_per_image, labels_per_image, 0.5)
box_pred_per_image = box_pred_per_image[keep]
scores_per_image = scores_per_image[keep]
labels_per_image = labels_per_image[keep]
result = Instances(image_size)
result.pred_boxes = Boxes(box_pred_per_image)
result.scores = scores_per_image
result.pred_classes = labels_per_image
results.append(result)
return results
def preprocess_image(self, batched_inputs):
"""
Normalize, pad and batch the input images.
"""
images = [self.normalizer(x["image"].to(self.device)) for x in batched_inputs]
images = ImageList.from_tensors(images, self.size_divisibility)
images_whwh = list()
for bi in batched_inputs:
h, w = bi["image"].shape[-2:]
images_whwh.append(torch.tensor([w, h, w, h], dtype=torch.float32, device=self.device))
images_whwh = torch.stack(images_whwh)
return images, images_whwh
# ========================================
# Modified by Shoufa Chen
# ========================================
# Modified by Peize Sun, Rufeng Zhang
# Contact: {sunpeize, cxrfzhang}@foxmail.com
#
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DiffusionDet Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
* positional encodings are passed in MHattention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
"""
import copy
import math
import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from detectron2.modeling.poolers import ROIPooler
from detectron2.structures import Boxes
_DEFAULT_SCALE_CLAMP = math.log(100000.0 / 16)
class SinusoidalPositionEmbeddings(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, time):
device = time.device
half_dim = self.dim // 2
embeddings = math.log(10000) / (half_dim - 1)
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
embeddings = time[:, None] * embeddings[None, :]
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
return embeddings
class GaussianFourierProjection(nn.Module):
"""Gaussian random features for encoding time steps."""
def __init__(self, embed_dim, scale=30.):
super().__init__()
# Randomly sample weights during initialization. These weights are fixed
# during optimization and are not trainable.
self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
def forward(self, x):
x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
class Dense(nn.Module):
"""A fully connected layer that reshapes outputs to feature maps."""
def __init__(self, input_dim, output_dim):
super().__init__()
self.dense = nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.dense(x)
class DynamicHead(nn.Module):
def __init__(self, cfg, roi_input_shape):
super().__init__()
# Build RoI.
box_pooler = self._init_box_pooler(cfg, roi_input_shape)
self.box_pooler = box_pooler
# Build heads.
num_classes = cfg.MODEL.DiffusionDet.NUM_CLASSES
d_model = cfg.MODEL.DiffusionDet.HIDDEN_DIM
dim_feedforward = cfg.MODEL.DiffusionDet.DIM_FEEDFORWARD
nhead = cfg.MODEL.DiffusionDet.NHEADS
dropout = cfg.MODEL.DiffusionDet.DROPOUT
activation = cfg.MODEL.DiffusionDet.ACTIVATION
num_heads = cfg.MODEL.DiffusionDet.NUM_HEADS
rcnn_head = RCNNHead(cfg, d_model, num_classes, dim_feedforward, nhead, dropout, activation)
self.head_series = _get_clones(rcnn_head, num_heads)
self.num_heads = num_heads
self.return_intermediate = cfg.MODEL.DiffusionDet.DEEP_SUPERVISION
# Gaussian random feature embedding layer for time
self.d_model = d_model
time_dim = d_model * 4
self.time_mlp = nn.Sequential(
SinusoidalPositionEmbeddings(d_model),
nn.Linear(d_model, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim),
)
# Init parameters.
self.use_focal = cfg.MODEL.DiffusionDet.USE_FOCAL
self.use_fed_loss = cfg.MODEL.DiffusionDet.USE_FED_LOSS
self.num_classes = num_classes
if self.use_focal or self.use_fed_loss:
prior_prob = cfg.MODEL.DiffusionDet.PRIOR_PROB
self.bias_value = -math.log((1 - prior_prob) / prior_prob)
self._reset_parameters()
def _reset_parameters(self):
# init all parameters.
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
# initialize the bias for focal loss and fed loss.
if self.use_focal or self.use_fed_loss:
if p.shape[-1] == self.num_classes or p.shape[-1] == self.num_classes + 1:
nn.init.constant_(p, self.bias_value)
@staticmethod
def _init_box_pooler(cfg, input_shape):
in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES
pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features)
sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE
# If StandardROIHeads is applied on multiple feature maps (as in FPN),
# then we share the same predictors and therefore the channel counts must be the same
in_channels = [input_shape[f].channels for f in in_features]
# Check all channel counts are equal
assert len(set(in_channels)) == 1, in_channels
box_pooler = ROIPooler(
output_size=pooler_resolution,
scales=pooler_scales,
sampling_ratio=sampling_ratio,
pooler_type=pooler_type,
)
return box_pooler
def forward(self, features, init_bboxes, t, init_features):
# assert t shape (batch_size)
time = self.time_mlp(t)
inter_class_logits = []
inter_pred_bboxes = []
bs = len(features[0])
bboxes = init_bboxes
num_boxes = bboxes.shape[1]
if init_features is not None:
init_features = init_features[None].repeat(1, bs, 1)
proposal_features = init_features.clone()
else:
proposal_features = None
for head_idx, rcnn_head in enumerate(self.head_series):
class_logits, pred_bboxes, proposal_features = rcnn_head(features, bboxes, proposal_features, self.box_pooler, time)
if self.return_intermediate:
inter_class_logits.append(class_logits)
inter_pred_bboxes.append(pred_bboxes)
bboxes = pred_bboxes.detach()
if self.return_intermediate:
return torch.stack(inter_class_logits), torch.stack(inter_pred_bboxes)
return class_logits[None], pred_bboxes[None]
class RCNNHead(nn.Module):
def __init__(self, cfg, d_model, num_classes, dim_feedforward=2048, nhead=8, dropout=0.1, activation="relu",
scale_clamp: float = _DEFAULT_SCALE_CLAMP, bbox_weights=(2.0, 2.0, 1.0, 1.0)):
super().__init__()
self.d_model = d_model
# dynamic.
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.inst_interact = DynamicConv(cfg)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
# block time mlp
self.block_time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(d_model * 4, d_model * 2))
# cls.
num_cls = cfg.MODEL.DiffusionDet.NUM_CLS
cls_module = list()
for _ in range(num_cls):
cls_module.append(nn.Linear(d_model, d_model, False))
cls_module.append(nn.LayerNorm(d_model))
cls_module.append(nn.ReLU(inplace=True))
self.cls_module = nn.ModuleList(cls_module)
# reg.
num_reg = cfg.MODEL.DiffusionDet.NUM_REG
reg_module = list()
for _ in range(num_reg):
reg_module.append(nn.Linear(d_model, d_model, False))
reg_module.append(nn.LayerNorm(d_model))
reg_module.append(nn.ReLU(inplace=True))
self.reg_module = nn.ModuleList(reg_module)
# pred.
self.use_focal = cfg.MODEL.DiffusionDet.USE_FOCAL
self.use_fed_loss = cfg.MODEL.DiffusionDet.USE_FED_LOSS
if self.use_focal or self.use_fed_loss:
self.class_logits = nn.Linear(d_model, num_classes)
else:
self.class_logits = nn.Linear(d_model, num_classes + 1)
self.bboxes_delta = nn.Linear(d_model, 4)
self.scale_clamp = scale_clamp
self.bbox_weights = bbox_weights
def forward(self, features, bboxes, pro_features, pooler, time_emb):
"""
:param bboxes: (N, nr_boxes, 4)
:param pro_features: (N, nr_boxes, d_model)
"""
N, nr_boxes = bboxes.shape[:2]
# roi_feature.
proposal_boxes = list()
for b in range(N):
proposal_boxes.append(Boxes(bboxes[b]))
roi_features = pooler(features, proposal_boxes)
if pro_features is None:
pro_features = roi_features.view(N, nr_boxes, self.d_model, -1).mean(-1)
roi_features = roi_features.view(N * nr_boxes, self.d_model, -1).permute(2, 0, 1)
# self_att.
pro_features = pro_features.view(N, nr_boxes, self.d_model).permute(1, 0, 2)
pro_features2 = self.self_attn(pro_features, pro_features, value=pro_features)[0]
pro_features = pro_features + self.dropout1(pro_features2)
pro_features = self.norm1(pro_features)
# inst_interact.
pro_features = pro_features.view(nr_boxes, N, self.d_model).permute(1, 0, 2).reshape(1, N * nr_boxes, self.d_model)
pro_features2 = self.inst_interact(pro_features, roi_features)
pro_features = pro_features + self.dropout2(pro_features2)
obj_features = self.norm2(pro_features)
# obj_feature.
obj_features2 = self.linear2(self.dropout(self.activation(self.linear1(obj_features))))
obj_features = obj_features + self.dropout3(obj_features2)
obj_features = self.norm3(obj_features)
fc_feature = obj_features.transpose(0, 1).reshape(N * nr_boxes, -1)
scale_shift = self.block_time_mlp(time_emb)
scale_shift = torch.repeat_interleave(scale_shift, nr_boxes, dim=0)
scale, shift = scale_shift.chunk(2, dim=1)
fc_feature = fc_feature * (scale + 1) + shift
cls_feature = fc_feature.clone()
reg_feature = fc_feature.clone()
for cls_layer in self.cls_module:
cls_feature = cls_layer(cls_feature)
for reg_layer in self.reg_module:
reg_feature = reg_layer(reg_feature)
class_logits = self.class_logits(cls_feature)
bboxes_deltas = self.bboxes_delta(reg_feature)
pred_bboxes = self.apply_deltas(bboxes_deltas, bboxes.view(-1, 4))
return class_logits.view(N, nr_boxes, -1), pred_bboxes.view(N, nr_boxes, -1), obj_features
def apply_deltas(self, deltas, boxes):
"""
Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`.
Args:
deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1.
deltas[i] represents k potentially different class-specific
box transformations for the single box boxes[i].
boxes (Tensor): boxes to transform, of shape (N, 4)
"""
boxes = boxes.to(deltas.dtype)
widths = boxes[:, 2] - boxes[:, 0]
heights = boxes[:, 3] - boxes[:, 1]
ctr_x = boxes[:, 0] + 0.5 * widths
ctr_y = boxes[:, 1] + 0.5 * heights
wx, wy, ww, wh = self.bbox_weights
dx = deltas[:, 0::4] / wx
dy = deltas[:, 1::4] / wy
dw = deltas[:, 2::4] / ww
dh = deltas[:, 3::4] / wh
# Prevent sending too large values into torch.exp()
dw = torch.clamp(dw, max=self.scale_clamp)
dh = torch.clamp(dh, max=self.scale_clamp)
pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
pred_w = torch.exp(dw) * widths[:, None]
pred_h = torch.exp(dh) * heights[:, None]
pred_boxes = torch.zeros_like(deltas)
pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w # x1
pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h # y1
pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w # x2
pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h # y2
return pred_boxes
class DynamicConv(nn.Module):
def __init__(self, cfg):
super().__init__()
self.hidden_dim = cfg.MODEL.DiffusionDet.HIDDEN_DIM
self.dim_dynamic = cfg.MODEL.DiffusionDet.DIM_DYNAMIC
self.num_dynamic = cfg.MODEL.DiffusionDet.NUM_DYNAMIC
self.num_params = self.hidden_dim * self.dim_dynamic
self.dynamic_layer = nn.Linear(self.hidden_dim, self.num_dynamic * self.num_params)
self.norm1 = nn.LayerNorm(self.dim_dynamic)
self.norm2 = nn.LayerNorm(self.hidden_dim)
self.activation = nn.ReLU(inplace=True)
pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
num_output = self.hidden_dim * pooler_resolution ** 2
self.out_layer = nn.Linear(num_output, self.hidden_dim)
self.norm3 = nn.LayerNorm(self.hidden_dim)
def forward(self, pro_features, roi_features):
'''
pro_features: (1, N * nr_boxes, self.d_model)
roi_features: (49, N * nr_boxes, self.d_model)
'''
features = roi_features.permute(1, 0, 2)
parameters = self.dynamic_layer(pro_features).permute(1, 0, 2)
param1 = parameters[:, :, :self.num_params].view(-1, self.hidden_dim, self.dim_dynamic)
param2 = parameters[:, :, self.num_params:].view(-1, self.dim_dynamic, self.hidden_dim)
features = torch.bmm(features, param1)
features = self.norm1(features)
features = self.activation(features)
features = torch.bmm(features, param2)
features = self.norm2(features)
features = self.activation(features)
features = features.flatten(1)
features = self.out_layer(features)
features = self.norm3(features)
features = self.activation(features)
return features
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
# ========================================
# Modified by Shoufa Chen
# ========================================
# Modified by Peize Sun, Rufeng Zhang
# Contact: {sunpeize, cxrfzhang}@foxmail.com
#
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DiffusionDet model and criterion classes.
"""
import torch
import torch.nn.functional as F
from torch import nn
from fvcore.nn import sigmoid_focal_loss_jit
import torchvision.ops as ops
from .util import box_ops
from .util.misc import get_world_size, is_dist_avail_and_initialized
from .util.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh, generalized_box_iou
class SetCriterionDynamicK(nn.Module):
""" This class computes the loss for DiffusionDet.
The process happens in two steps:
1) we compute hungarian assignment between ground truth boxes and the outputs of the model
2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
"""
def __init__(self, cfg, num_classes, matcher, weight_dict, eos_coef, losses, use_focal):
""" Create the criterion.
Parameters:
num_classes: number of object categories, omitting the special no-object category
matcher: module able to compute a matching between targets and proposals
weight_dict: dict containing as key the names of the losses and as values their relative weight.
eos_coef: relative classification weight applied to the no-object category
losses: list of all the losses to be applied. See get_loss for list of available losses.
"""
super().__init__()
self.cfg = cfg
self.num_classes = num_classes
self.matcher = matcher
self.weight_dict = weight_dict
self.eos_coef = eos_coef
self.losses = losses
self.use_focal = use_focal
self.use_fed_loss = cfg.MODEL.DiffusionDet.USE_FED_LOSS
if self.use_fed_loss:
self.fed_loss_num_classes = 50
from detectron2.data.detection_utils import get_fed_loss_cls_weights
cls_weight_fun = lambda: get_fed_loss_cls_weights(dataset_names=cfg.DATASETS.TRAIN, freq_weight_power=cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT_POWER) # noqa
fed_loss_cls_weights = cls_weight_fun()
assert (
len(fed_loss_cls_weights) == self.num_classes
), "Please check the provided fed_loss_cls_weights. Their size should match num_classes"
self.register_buffer("fed_loss_cls_weights", fed_loss_cls_weights)
if self.use_focal:
self.focal_loss_alpha = cfg.MODEL.DiffusionDet.ALPHA
self.focal_loss_gamma = cfg.MODEL.DiffusionDet.GAMMA
else:
empty_weight = torch.ones(self.num_classes + 1)
empty_weight[-1] = self.eos_coef
self.register_buffer('empty_weight', empty_weight)
# copy-paste from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/roi_heads/fast_rcnn.py#L356
def get_fed_loss_classes(self, gt_classes, num_fed_loss_classes, num_classes, weight):
"""
Args:
gt_classes: a long tensor of shape R that contains the gt class label of each proposal.
num_fed_loss_classes: minimum number of classes to keep when calculating federated loss.
Will sample negative classes if number of unique gt_classes is smaller than this value.
num_classes: number of foreground classes
weight: probabilities used to sample negative classes
Returns:
Tensor:
classes to keep when calculating the federated loss, including both unique gt
classes and sampled negative classes.
"""
unique_gt_classes = torch.unique(gt_classes)
prob = unique_gt_classes.new_ones(num_classes + 1).float()
prob[-1] = 0
if len(unique_gt_classes) < num_fed_loss_classes:
prob[:num_classes] = weight.float().clone()
prob[unique_gt_classes] = 0
sampled_negative_classes = torch.multinomial(
prob, num_fed_loss_classes - len(unique_gt_classes), replacement=False
)
fed_loss_classes = torch.cat([unique_gt_classes, sampled_negative_classes])
else:
fed_loss_classes = unique_gt_classes
return fed_loss_classes
def loss_labels(self, outputs, targets, indices, num_boxes, log=False):
"""Classification loss (NLL)
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
"""
assert 'pred_logits' in outputs
src_logits = outputs['pred_logits']
batch_size = len(targets)
# idx = self._get_src_permutation_idx(indices)
# target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
target_classes = torch.full(src_logits.shape[:2], self.num_classes,
dtype=torch.int64, device=src_logits.device)
src_logits_list = []
target_classes_o_list = []
# target_classes[idx] = target_classes_o
for batch_idx in range(batch_size):
valid_query = indices[batch_idx][0]
gt_multi_idx = indices[batch_idx][1]
if len(gt_multi_idx) == 0:
continue
bz_src_logits = src_logits[batch_idx]
target_classes_o = targets[batch_idx]["labels"]
target_classes[batch_idx, valid_query] = target_classes_o[gt_multi_idx]
src_logits_list.append(bz_src_logits[valid_query])
target_classes_o_list.append(target_classes_o[gt_multi_idx])
if self.use_focal or self.use_fed_loss:
num_boxes = torch.cat(target_classes_o_list).shape[0] if len(target_classes_o_list) != 0 else 1
target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], self.num_classes + 1],
dtype=src_logits.dtype, layout=src_logits.layout,
device=src_logits.device)
target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
gt_classes = torch.argmax(target_classes_onehot, dim=-1)
target_classes_onehot = target_classes_onehot[:, :, :-1]
src_logits = src_logits.flatten(0, 1)
target_classes_onehot = target_classes_onehot.flatten(0, 1)
if self.use_focal:
cls_loss = sigmoid_focal_loss_jit(src_logits, target_classes_onehot, alpha=self.focal_loss_alpha, gamma=self.focal_loss_gamma, reduction="none")
else:
cls_loss = F.binary_cross_entropy_with_logits(src_logits, target_classes_onehot, reduction="none")
if self.use_fed_loss:
K = self.num_classes
N = src_logits.shape[0]
fed_loss_classes = self.get_fed_loss_classes(
gt_classes,
num_fed_loss_classes=self.fed_loss_num_classes,
num_classes=K,
weight=self.fed_loss_cls_weights,
)
fed_loss_classes_mask = fed_loss_classes.new_zeros(K + 1)
fed_loss_classes_mask[fed_loss_classes] = 1
fed_loss_classes_mask = fed_loss_classes_mask[:K]
weight = fed_loss_classes_mask.view(1, K).expand(N, K).float()
loss_ce = torch.sum(cls_loss * weight) / num_boxes
else:
loss_ce = torch.sum(cls_loss) / num_boxes
losses = {'loss_ce': loss_ce}
else:
raise NotImplementedError
return losses
def loss_boxes(self, outputs, targets, indices, num_boxes):
"""Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
"""
assert 'pred_boxes' in outputs
# idx = self._get_src_permutation_idx(indices)
src_boxes = outputs['pred_boxes']
batch_size = len(targets)
pred_box_list = []
pred_norm_box_list = []
tgt_box_list = []
tgt_box_xyxy_list = []
for batch_idx in range(batch_size):
valid_query = indices[batch_idx][0]
gt_multi_idx = indices[batch_idx][1]
if len(gt_multi_idx) == 0:
continue
bz_image_whwh = targets[batch_idx]['image_size_xyxy']
bz_src_boxes = src_boxes[batch_idx]
bz_target_boxes = targets[batch_idx]["boxes"] # normalized (cx, cy, w, h)
bz_target_boxes_xyxy = targets[batch_idx]["boxes_xyxy"] # absolute (x1, y1, x2, y2)
pred_box_list.append(bz_src_boxes[valid_query])
pred_norm_box_list.append(bz_src_boxes[valid_query] / bz_image_whwh) # normalize (x1, y1, x2, y2)
tgt_box_list.append(bz_target_boxes[gt_multi_idx])
tgt_box_xyxy_list.append(bz_target_boxes_xyxy[gt_multi_idx])
if len(pred_box_list) != 0:
src_boxes = torch.cat(pred_box_list)
src_boxes_norm = torch.cat(pred_norm_box_list) # normalized (x1, y1, x2, y2)
target_boxes = torch.cat(tgt_box_list)
target_boxes_abs_xyxy = torch.cat(tgt_box_xyxy_list)
num_boxes = src_boxes.shape[0]
losses = {}
# require normalized (x1, y1, x2, y2)
loss_bbox = F.l1_loss(src_boxes_norm, box_cxcywh_to_xyxy(target_boxes), reduction='none')
losses['loss_bbox'] = loss_bbox.sum() / num_boxes
# loss_giou = giou_loss(box_ops.box_cxcywh_to_xyxy(src_boxes), box_ops.box_cxcywh_to_xyxy(target_boxes))
loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(src_boxes, target_boxes_abs_xyxy))
losses['loss_giou'] = loss_giou.sum() / num_boxes
else:
losses = {'loss_bbox': outputs['pred_boxes'].sum() * 0,
'loss_giou': outputs['pred_boxes'].sum() * 0}
return losses
def _get_src_permutation_idx(self, indices):
# permute predictions following indices
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = torch.cat([src for (src, _) in indices])
return batch_idx, src_idx
def _get_tgt_permutation_idx(self, indices):
# permute targets following indices
batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
return batch_idx, tgt_idx
def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
loss_map = {
'labels': self.loss_labels,
'boxes': self.loss_boxes,
}
assert loss in loss_map, f'do you really want to compute {loss} loss?'
return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
def forward(self, outputs, targets):
""" This performs the loss computation.
Parameters:
outputs: dict of tensors, see the output specification of the model for the format
targets: list of dicts, such that len(targets) == batch_size.
The expected keys in each dict depends on the losses applied, see each loss' doc
"""
outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
# Retrieve the matching between the outputs of the last layer and the targets
indices, _ = self.matcher(outputs_without_aux, targets)
# Compute the average number of target boxes accross all nodes, for normalization purposes
num_boxes = sum(len(t["labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
if is_dist_avail_and_initialized():
torch.distributed.all_reduce(num_boxes)
num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
# Compute all the requested losses
losses = {}
for loss in self.losses:
losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
if 'aux_outputs' in outputs:
for i, aux_outputs in enumerate(outputs['aux_outputs']):
indices, _ = self.matcher(aux_outputs, targets)
for loss in self.losses:
if loss == 'masks':
# Intermediate masks losses are too costly to compute, we ignore them.
continue
kwargs = {}
if loss == 'labels':
# Logging is enabled only for the last layer
kwargs = {'log': False}
l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
losses.update(l_dict)
return losses
class HungarianMatcherDynamicK(nn.Module):
"""This class computes an assignment between the targets and the predictions of the network
For efficiency reasons, the targets don't include the no_object. Because of this, in general,
there are more predictions than targets. In this case, we do a 1-to-k (dynamic) matching of the best predictions,
while the others are un-matched (and thus treated as non-objects).
"""
def __init__(self, cfg, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1, cost_mask: float = 1, use_focal: bool = False):
"""Creates the matcher
Params:
cost_class: This is the relative weight of the classification error in the matching cost
cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
"""
super().__init__()
self.cost_class = cost_class
self.cost_bbox = cost_bbox
self.cost_giou = cost_giou
self.use_focal = use_focal
self.use_fed_loss = cfg.MODEL.DiffusionDet.USE_FED_LOSS
self.ota_k = cfg.MODEL.DiffusionDet.OTA_K
if self.use_focal:
self.focal_loss_alpha = cfg.MODEL.DiffusionDet.ALPHA
self.focal_loss_gamma = cfg.MODEL.DiffusionDet.GAMMA
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"
def forward(self, outputs, targets):
""" simOTA for detr"""
with torch.no_grad():
bs, num_queries = outputs["pred_logits"].shape[:2]
# We flatten to compute the cost matrices in a batch
if self.use_focal or self.use_fed_loss:
out_prob = outputs["pred_logits"].sigmoid() # [batch_size, num_queries, num_classes]
out_bbox = outputs["pred_boxes"] # [batch_size, num_queries, 4]
else:
out_prob = outputs["pred_logits"].softmax(-1) # [batch_size, num_queries, num_classes]
out_bbox = outputs["pred_boxes"] # [batch_size, num_queries, 4]
indices = []
matched_ids = []
assert bs == len(targets)
for batch_idx in range(bs):
bz_boxes = out_bbox[batch_idx] # [num_proposals, 4]
bz_out_prob = out_prob[batch_idx]
bz_tgt_ids = targets[batch_idx]["labels"]
num_insts = len(bz_tgt_ids)
if num_insts == 0: # empty object in key frame
non_valid = torch.zeros(bz_out_prob.shape[0]).to(bz_out_prob) > 0
indices_batchi = (non_valid, torch.arange(0, 0).to(bz_out_prob))
matched_qidx = torch.arange(0, 0).to(bz_out_prob)
indices.append(indices_batchi)
matched_ids.append(matched_qidx)
continue
bz_gtboxs = targets[batch_idx]['boxes'] # [num_gt, 4] normalized (cx, xy, w, h)
bz_gtboxs_abs_xyxy = targets[batch_idx]['boxes_xyxy']
fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
box_xyxy_to_cxcywh(bz_boxes), # absolute (cx, cy, w, h)
box_xyxy_to_cxcywh(bz_gtboxs_abs_xyxy), # absolute (cx, cy, w, h)
expanded_strides=32
)
pair_wise_ious = ops.box_iou(bz_boxes, bz_gtboxs_abs_xyxy)
# Compute the classification cost.
if self.use_focal:
alpha = self.focal_loss_alpha
gamma = self.focal_loss_gamma
neg_cost_class = (1 - alpha) * (bz_out_prob ** gamma) * (-(1 - bz_out_prob + 1e-8).log())
pos_cost_class = alpha * ((1 - bz_out_prob) ** gamma) * (-(bz_out_prob + 1e-8).log())
cost_class = pos_cost_class[:, bz_tgt_ids] - neg_cost_class[:, bz_tgt_ids]
elif self.use_fed_loss:
# focal loss degenerates to naive one
neg_cost_class = (-(1 - bz_out_prob + 1e-8).log())
pos_cost_class = (-(bz_out_prob + 1e-8).log())
cost_class = pos_cost_class[:, bz_tgt_ids] - neg_cost_class[:, bz_tgt_ids]
else:
cost_class = -bz_out_prob[:, bz_tgt_ids]
# Compute the L1 cost between boxes
# image_size_out = torch.cat([v["image_size_xyxy"].unsqueeze(0) for v in targets])
# image_size_out = image_size_out.unsqueeze(1).repeat(1, num_queries, 1).flatten(0, 1)
# image_size_tgt = torch.cat([v["image_size_xyxy_tgt"] for v in targets])
bz_image_size_out = targets[batch_idx]['image_size_xyxy']
bz_image_size_tgt = targets[batch_idx]['image_size_xyxy_tgt']
bz_out_bbox_ = bz_boxes / bz_image_size_out # normalize (x1, y1, x2, y2)
bz_tgt_bbox_ = bz_gtboxs_abs_xyxy / bz_image_size_tgt # normalize (x1, y1, x2, y2)
cost_bbox = torch.cdist(bz_out_bbox_, bz_tgt_bbox_, p=1)
cost_giou = -generalized_box_iou(bz_boxes, bz_gtboxs_abs_xyxy)
# Final cost matrix
cost = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou + 100.0 * (~is_in_boxes_and_center)
# cost = (cost_class + 3.0 * cost_giou + 100.0 * (~is_in_boxes_and_center)) # [num_query,num_gt]
cost[~fg_mask] = cost[~fg_mask] + 10000.0
# if bz_gtboxs.shape[0]>0:
indices_batchi, matched_qidx = self.dynamic_k_matching(cost, pair_wise_ious, bz_gtboxs.shape[0])
indices.append(indices_batchi)
matched_ids.append(matched_qidx)
return indices, matched_ids
def get_in_boxes_info(self, boxes, target_gts, expanded_strides):
xy_target_gts = box_cxcywh_to_xyxy(target_gts) # (x1, y1, x2, y2)
anchor_center_x = boxes[:, 0].unsqueeze(1)
anchor_center_y = boxes[:, 1].unsqueeze(1)
# whether the center of each anchor is inside a gt box
b_l = anchor_center_x > xy_target_gts[:, 0].unsqueeze(0)
b_r = anchor_center_x < xy_target_gts[:, 2].unsqueeze(0)
b_t = anchor_center_y > xy_target_gts[:, 1].unsqueeze(0)
b_b = anchor_center_y < xy_target_gts[:, 3].unsqueeze(0)
# (b_l.long()+b_r.long()+b_t.long()+b_b.long())==4 [300,num_gt] ,
is_in_boxes = ((b_l.long() + b_r.long() + b_t.long() + b_b.long()) == 4)
is_in_boxes_all = is_in_boxes.sum(1) > 0 # [num_query]
# in fixed center
center_radius = 2.5
# Modified to self-adapted sampling --- the center size depends on the size of the gt boxes
# https://github.com/dulucas/UVO_Challenge/blob/main/Track1/detection/mmdet/core/bbox/assigners/rpn_sim_ota_assigner.py#L212
b_l = anchor_center_x > (target_gts[:, 0] - (center_radius * (xy_target_gts[:, 2] - xy_target_gts[:, 0]))).unsqueeze(0)
b_r = anchor_center_x < (target_gts[:, 0] + (center_radius * (xy_target_gts[:, 2] - xy_target_gts[:, 0]))).unsqueeze(0)
b_t = anchor_center_y > (target_gts[:, 1] - (center_radius * (xy_target_gts[:, 3] - xy_target_gts[:, 1]))).unsqueeze(0)
b_b = anchor_center_y < (target_gts[:, 1] + (center_radius * (xy_target_gts[:, 3] - xy_target_gts[:, 1]))).unsqueeze(0)
is_in_centers = ((b_l.long() + b_r.long() + b_t.long() + b_b.long()) == 4)
is_in_centers_all = is_in_centers.sum(1) > 0
is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
is_in_boxes_and_center = (is_in_boxes & is_in_centers)
return is_in_boxes_anchor, is_in_boxes_and_center
def dynamic_k_matching(self, cost, pair_wise_ious, num_gt):
matching_matrix = torch.zeros_like(cost) # [300,num_gt]
ious_in_boxes_matrix = pair_wise_ious
n_candidate_k = self.ota_k
# Take the sum of the predicted value and the top 10 iou of gt with the largest iou as dynamic_k
topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=0)
dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
for gt_idx in range(num_gt):
_, pos_idx = torch.topk(cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
matching_matrix[:, gt_idx][pos_idx] = 1.0
del topk_ious, dynamic_ks, pos_idx
anchor_matching_gt = matching_matrix.sum(1)
if (anchor_matching_gt > 1).sum() > 0:
_, cost_argmin = torch.min(cost[anchor_matching_gt > 1], dim=1)
matching_matrix[anchor_matching_gt > 1] *= 0
matching_matrix[anchor_matching_gt > 1, cost_argmin,] = 1
while (matching_matrix.sum(0) == 0).any():
num_zero_gt = (matching_matrix.sum(0) == 0).sum()
matched_query_id = matching_matrix.sum(1) > 0
cost[matched_query_id] += 100000.0
unmatch_id = torch.nonzero(matching_matrix.sum(0) == 0, as_tuple=False).squeeze(1)
for gt_idx in unmatch_id:
pos_idx = torch.argmin(cost[:, gt_idx])
matching_matrix[:, gt_idx][pos_idx] = 1.0
if (matching_matrix.sum(1) > 1).sum() > 0: # If a query matches more than one gt
_, cost_argmin = torch.min(cost[anchor_matching_gt > 1],
dim=1) # find gt for these queries with minimal cost
matching_matrix[anchor_matching_gt > 1] *= 0 # reset mapping relationship
matching_matrix[anchor_matching_gt > 1, cost_argmin,] = 1 # keep gt with minimal cost
assert not (matching_matrix.sum(0) == 0).any()
selected_query = matching_matrix.sum(1) > 0
gt_indices = matching_matrix[selected_query].max(1)[1]
assert selected_query.sum() == len(gt_indices)
cost[matching_matrix == 0] = cost[matching_matrix == 0] + float('inf')
matched_query_id = torch.min(cost, dim=0)[1]
return (selected_query, gt_indices), matched_query_id
# Copyright (c) Facebook, Inc. and its affiliates.
import atexit
import bisect
import multiprocessing as mp
from collections import deque
import cv2
import torch
from detectron2.data import MetadataCatalog
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.video_visualizer import VideoVisualizer
from detectron2.utils.visualizer import ColorMode, Visualizer
class VisualizationDemo(object):
def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
"""
Args:
cfg (CfgNode):
instance_mode (ColorMode):
parallel (bool): whether to run the model in different processes from visualization.
Useful since the visualization logic can be slow.
"""
self.metadata = MetadataCatalog.get(
cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
)
self.cpu_device = torch.device("cpu")
self.instance_mode = instance_mode
self.parallel = parallel
if parallel:
num_gpu = torch.cuda.device_count()
self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
else:
self.predictor = DefaultPredictor(cfg)
self.threshold = cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST # workaround
def run_on_image(self, image):
"""
Args:
image (np.ndarray): an image of shape (H, W, C) (in BGR order).
This is the format used by OpenCV.
Returns:
predictions (dict): the output of the model.
vis_output (VisImage): the visualized image output.
"""
vis_output = None
predictions = self.predictor(image)
# Filter
instances = predictions['instances']
new_instances = instances[instances.scores > self.threshold]
predictions = {'instances': new_instances}
# Convert image from OpenCV BGR format to Matplotlib RGB format.
image = image[:, :, ::-1]
visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
if "panoptic_seg" in predictions:
panoptic_seg, segments_info = predictions["panoptic_seg"]
vis_output = visualizer.draw_panoptic_seg_predictions(
panoptic_seg.to(self.cpu_device), segments_info
)
else:
if "sem_seg" in predictions:
vis_output = visualizer.draw_sem_seg(
predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
)
if "instances" in predictions:
instances = predictions["instances"].to(self.cpu_device)
vis_output = visualizer.draw_instance_predictions(predictions=instances)
return predictions, vis_output
def _frame_from_video(self, video):
while video.isOpened():
success, frame = video.read()
if success:
yield frame
else:
break
def run_on_video(self, video):
"""
Visualizes predictions on frames of the input video.
Args:
video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be
either a webcam or a video file.
Yields:
ndarray: BGR visualizations of each video frame.
"""
video_visualizer = VideoVisualizer(self.metadata, self.instance_mode)
def process_predictions(frame, predictions):
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
if "panoptic_seg" in predictions:
panoptic_seg, segments_info = predictions["panoptic_seg"]
vis_frame = video_visualizer.draw_panoptic_seg_predictions(
frame, panoptic_seg.to(self.cpu_device), segments_info
)
elif "instances" in predictions:
predictions = predictions["instances"].to(self.cpu_device)
vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
elif "sem_seg" in predictions:
vis_frame = video_visualizer.draw_sem_seg(
frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
)
# Converts Matplotlib RGB format to OpenCV BGR format
vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR)
return vis_frame
frame_gen = self._frame_from_video(video)
if self.parallel:
buffer_size = self.predictor.default_buffer_size
frame_data = deque()
for cnt, frame in enumerate(frame_gen):
frame_data.append(frame)
self.predictor.put(frame)
if cnt >= buffer_size:
frame = frame_data.popleft()
predictions = self.predictor.get()
yield process_predictions(frame, predictions)
while len(frame_data):
frame = frame_data.popleft()
predictions = self.predictor.get()
yield process_predictions(frame, predictions)
else:
for frame in frame_gen:
yield process_predictions(frame, self.predictor(frame))
class AsyncPredictor:
"""
A predictor that runs the model asynchronously, possibly on >1 GPUs.
Because rendering the visualization takes considerably amount of time,
this helps improve throughput a little bit when rendering videos.
"""
class _StopToken:
pass
class _PredictWorker(mp.Process):
def __init__(self, cfg, task_queue, result_queue):
self.cfg = cfg
self.task_queue = task_queue
self.result_queue = result_queue
super().__init__()
def run(self):
predictor = DefaultPredictor(self.cfg)
while True:
task = self.task_queue.get()
if isinstance(task, AsyncPredictor._StopToken):
break
idx, data = task
result = predictor(data)
self.result_queue.put((idx, result))
def __init__(self, cfg, num_gpus: int = 1):
"""
Args:
cfg (CfgNode):
num_gpus (int): if 0, will run on CPU
"""
num_workers = max(num_gpus, 1)
self.task_queue = mp.Queue(maxsize=num_workers * 3)
self.result_queue = mp.Queue(maxsize=num_workers * 3)
self.procs = []
for gpuid in range(max(num_gpus, 1)):
cfg = cfg.clone()
cfg.defrost()
cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
self.procs.append(
AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
)
self.put_idx = 0
self.get_idx = 0
self.result_rank = []
self.result_data = []
for p in self.procs:
p.start()
atexit.register(self.shutdown)
def put(self, image):
self.put_idx += 1
self.task_queue.put((self.put_idx, image))
def get(self):
self.get_idx += 1 # the index needed for this request
if len(self.result_rank) and self.result_rank[0] == self.get_idx:
res = self.result_data[0]
del self.result_data[0], self.result_rank[0]
return res
while True:
# make sure the results are returned in the correct order
idx, res = self.result_queue.get()
if idx == self.get_idx:
return res
insert = bisect.bisect(self.result_rank, idx)
self.result_rank.insert(insert, idx)
self.result_data.insert(insert, res)
def __len__(self):
return self.put_idx - self.get_idx
def __call__(self, image):
self.put(image)
return self.get()
def shutdown(self):
for _ in self.procs:
self.task_queue.put(AsyncPredictor._StopToken())
@property
def default_buffer_size(self):
return len(self.procs) * 5
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu, Yutong Lin, Yixuan Wei
# --------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Xingyi Zhou from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import fvcore.nn.weight_init as weight_init
from detectron2.layers import ShapeSpec
from detectron2.modeling.backbone.backbone import Backbone
from detectron2.modeling.backbone.build import BACKBONE_REGISTRY
from detectron2.modeling.backbone.fpn import FPN, LastLevelMaxPool
class LastLevelP6P7_P5(nn.Module):
"""
This module is used in RetinaNet to generate extra layers, P6 and P7 from
C5 feature.
"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.num_levels = 2
self.in_feature = "p5"
self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
for module in [self.p6, self.p7]:
weight_init.c2_xavier_fill(module)
def forward(self, c5):
p6 = self.p6(c5)
p7 = self.p7(F.relu(p6))
return [p6, p7]
class Mlp(nn.Module):
""" Multilayer perceptron."""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
""" Forward function.
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class SwinTransformerBlock(nn.Module):
""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.H = None
self.W = None
def forward(self, x, mask_matrix):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
mask_matrix: Attention mask for cyclic shift.
"""
B, L, C = x.shape
H, W = self.H, self.W
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# pad feature maps to multiples of window size
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchMerging(nn.Module):
""" Patch Merging Layer
Args:
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of feature channels
depth (int): Depths of this stage.
num_heads (int): Number of attention head.
window_size (int): Local window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self,
dim,
depth,
num_heads,
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False):
super().__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, H, W):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
for blk in self.blocks:
blk.H, blk.W = H, W
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, attn_mask)
else:
x = blk(x, attn_mask)
if self.downsample is not None:
x_down = self.downsample(x, H, W)
Wh, Ww = (H + 1) // 2, (W + 1) // 2
return x, H, W, x_down, Wh, Ww
else:
return x, H, W, x, H, W
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
Args:
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, H, W = x.size()
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x) # B C Wh Ww
if self.norm is not None:
Wh, Ww = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
return x
class SwinTransformer(Backbone):
""" Swin Transformer backbone.
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
pretrain_img_size (int): Input image size for training the pretrained model,
used in absolute postion embedding. Default 224.
patch_size (int | tuple(int)): Patch size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
depths (tuple[int]): Depths of each Swin Transformer stage.
num_heads (tuple[int]): Number of attention head of each stage.
window_size (int): Window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
drop_rate (float): Dropout rate.
attn_drop_rate (float): Attention dropout rate. Default: 0.
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
out_indices (Sequence[int]): Output from which stages.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self,
pretrain_img_size=224,
patch_size=4,
in_chans=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.2,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
use_checkpoint=False):
super().__init__()
self.pretrain_img_size = pretrain_img_size
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.out_indices = out_indices
self.frozen_stages = frozen_stages
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
# absolute position embedding
if self.ape:
pretrain_img_size = to_2tuple(pretrain_img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint)
self.layers.append(layer)
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
self.num_features = num_features
# add a norm layer for each output
for i_layer in out_indices:
layer = norm_layer(num_features[i_layer])
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
self._freeze_stages()
self._out_features = ['swin{}'.format(i) for i in self.out_indices]
self._out_feature_channels = {
'swin{}'.format(i): self.embed_dim * 2 ** i for i in self.out_indices
}
self._out_feature_strides = {
'swin{}'.format(i): 2 ** (i + 2) for i in self.out_indices
}
self._size_devisibility = 32
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
if self.frozen_stages >= 1 and self.ape:
self.absolute_pos_embed.requires_grad = False
if self.frozen_stages >= 2:
self.pos_drop.eval()
for i in range(0, self.frozen_stages - 1):
m = self.layers[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
def _init_weights(m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
if isinstance(pretrained, str):
self.apply(_init_weights)
# load_checkpoint(self, pretrained, strict=False)
elif pretrained is None:
self.apply(_init_weights)
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x):
"""Forward function."""
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
if self.ape:
# interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
else:
x = x.flatten(2).transpose(1, 2)
x = self.pos_drop(x)
# outs = []
outs = {}
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
# outs.append(out)
outs['swin{}'.format(i)] = out
return outs
def train(self, mode=True):
"""Convert the model into training mode while keep layers freezed."""
super(SwinTransformer, self).train(mode)
self._freeze_stages()
size2config = {
'T': {
'window_size': 7,
'embed_dim': 96,
'depth': [2, 2, 6, 2],
'num_heads': [3, 6, 12, 24],
'drop_path_rate': 0.2,
'pretrained': 'models/swin_tiny_patch4_window7_224.pth'
},
'S': {
'window_size': 7,
'embed_dim': 96,
'depth': [2, 2, 18, 2],
'num_heads': [3, 6, 12, 24],
'drop_path_rate': 0.2,
'pretrained': 'models/swin_small_patch4_window7_224.pth'
},
'B': {
'window_size': 7,
'embed_dim': 128,
'depth': [2, 2, 18, 2],
'num_heads': [4, 8, 16, 32],
'drop_path_rate': 0.3,
'pretrained': 'models/swin_base_patch4_window7_224.pth'
},
'B-22k': {
'window_size': 7,
'embed_dim': 128,
'depth': [2, 2, 18, 2],
'num_heads': [4, 8, 16, 32],
'drop_path_rate': 0.3,
'pretrained': 'models/swin_base_patch4_window7_224_22k.pth'
},
'B-22k-384': {
'window_size': 12,
'embed_dim': 128,
'depth': [2, 2, 18, 2],
'num_heads': [4, 8, 16, 32],
'drop_path_rate': 0.3,
'pretrained': 'models/swin_base_patch4_window12_384_22k.pth'
},
'L-22k': {
'window_size': 7,
'embed_dim': 192,
'depth': [2, 2, 18, 2],
'num_heads': [6, 12, 24, 48],
'drop_path_rate': 0.3, # TODO (xingyi): this is unclear
'pretrained': 'models/swin_large_patch4_window7_224_22k.pth'
},
'L-22k-384': {
'window_size': 12,
'embed_dim': 192,
'depth': [2, 2, 18, 2],
'num_heads': [6, 12, 24, 48],
'drop_path_rate': 0.3, # TODO (xingyi): this is unclear
'pretrained': 'models/swin_large_patch4_window12_384_22k.pth'
}
}
@BACKBONE_REGISTRY.register()
def build_swintransformer_backbone(cfg, input_shape):
"""
"""
config = size2config[cfg.MODEL.SWIN.SIZE]
out_indices = cfg.MODEL.SWIN.OUT_FEATURES
model = SwinTransformer(
embed_dim=config['embed_dim'],
window_size=config['window_size'],
depths=config['depth'],
num_heads=config['num_heads'],
drop_path_rate=config['drop_path_rate'],
out_indices=out_indices,
frozen_stages=-1,
use_checkpoint=cfg.MODEL.SWIN.USE_CHECKPOINT
)
# print('Initializing', config['pretrained'])
model.init_weights(config['pretrained'])
return model
@BACKBONE_REGISTRY.register()
def build_swintransformer_fpn_backbone(cfg, input_shape: ShapeSpec):
"""
"""
bottom_up = build_swintransformer_backbone(cfg, input_shape)
in_features = cfg.MODEL.FPN.IN_FEATURES
out_channels = cfg.MODEL.FPN.OUT_CHANNELS
backbone = FPN(
bottom_up=bottom_up,
in_features=in_features,
out_channels=out_channels,
norm=cfg.MODEL.FPN.NORM,
# top_block=LastLevelP6P7_P5(out_channels, out_channels),
top_block=LastLevelMaxPool(),
fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
)
return backbone
@BACKBONE_REGISTRY.register()
def build_swintransformer_bifpn_backbone(cfg, input_shape: ShapeSpec):
"""
"""
bottom_up = build_swintransformer_backbone(cfg, input_shape)
in_features = cfg.MODEL.FPN.IN_FEATURES
backbone = BiFPN(
cfg=cfg,
bottom_up=bottom_up,
in_features=in_features,
out_channels=cfg.MODEL.BIFPN.OUT_CHANNELS,
norm=cfg.MODEL.BIFPN.NORM,
num_levels=cfg.MODEL.BIFPN.NUM_LEVELS,
num_bifpn=cfg.MODEL.BIFPN.NUM_BIFPN,
separable_conv=cfg.MODEL.BIFPN.SEPARABLE_CONV,
)
return backbone
\ No newline at end of file
# ========================================
# Modified by Shoufa Chen
# ========================================
# Modified by Rufeng Zhang, Peize Sun
# Contact: {sunpeize, cxrfzhang}@foxmail.com
#
# Copyright (c) Megvii, Inc. and its affiliates. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#
from itertools import count
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from detectron2.modeling import GeneralizedRCNNWithTTA, DatasetMapperTTA
from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference_single_image
from detectron2.structures import Instances, Boxes
class DiffusionDetWithTTA(GeneralizedRCNNWithTTA):
"""
A DiffusionDet with test-time augmentation enabled.
Its :meth:`__call__` method has the same interface as :meth:`DiffusionDet.forward`.
"""
def __init__(self, cfg, model, tta_mapper=None, batch_size=3):
"""
Args:
cfg (CfgNode):
model (DiffusionDet): a DiffusionDet to apply TTA on.
tta_mapper (callable): takes a dataset dict and returns a list of
augmented versions of the dataset dict. Defaults to
`DatasetMapperTTA(cfg)`.
batch_size (int): batch the augmented images into this batch size for inference.
"""
# fix the issue: cannot assign module before Module.__init__() call
nn.Module.__init__(self)
if isinstance(model, DistributedDataParallel):
model = model.module
self.cfg = cfg.clone()
self.model = model
if tta_mapper is None:
tta_mapper = DatasetMapperTTA(cfg)
self.tta_mapper = tta_mapper
self.batch_size = batch_size
# cvpods tta.
self.enable_cvpods_tta = cfg.TEST.AUG.CVPODS_TTA
self.enable_scale_filter = cfg.TEST.AUG.SCALE_FILTER
self.scale_ranges = cfg.TEST.AUG.SCALE_RANGES
self.max_detection = cfg.MODEL.DiffusionDet.NUM_PROPOSALS
def _batch_inference(self, batched_inputs, detected_instances=None):
"""
Execute inference on a list of inputs,
using batch size = self.batch_size, instead of the length of the list.
Inputs & outputs have the same format as :meth:`DiffusionDet.forward`
"""
if detected_instances is None:
detected_instances = [None] * len(batched_inputs)
factors = 2 if self.tta_mapper.flip else 1
if self.enable_scale_filter:
assert len(batched_inputs) == len(self.scale_ranges) * factors
outputs = []
inputs, instances = [], []
for idx, input, instance in zip(count(), batched_inputs, detected_instances):
inputs.append(input)
instances.append(instance)
if self.enable_cvpods_tta:
output = self.model.forward(inputs, do_postprocess=False)[0]
if self.enable_scale_filter:
pred_boxes = output.get("pred_boxes")
keep = self.filter_boxes(pred_boxes.tensor, *self.scale_ranges[idx // factors])
output = Instances(
image_size=output.image_size,
pred_boxes=Boxes(pred_boxes.tensor[keep]),
pred_classes=output.pred_classes[keep],
scores=output.scores[keep])
outputs.extend([output])
else:
if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1:
outputs.extend(
self.model.forward(
inputs,
do_postprocess=False,
)
)
inputs, instances = [], []
return outputs
@staticmethod
def filter_boxes(boxes, min_scale, max_scale):
"""
boxes: (N, 4) shape
"""
# assert boxes.mode == "xyxy"
w = boxes[:, 2] - boxes[:, 0]
h = boxes[:, 3] - boxes[:, 1]
keep = (w * h > min_scale * min_scale) & (w * h < max_scale * max_scale)
return keep
def _inference_one_image(self, input):
"""
Args:
input (dict): one dataset dict with "image" field being a CHW tensor
Returns:
dict: one output dict
"""
orig_shape = (input["height"], input["width"])
augmented_inputs, tfms = self._get_augmented_inputs(input)
# Detect boxes from all augmented versions
all_boxes, all_scores, all_classes = self._get_augmented_boxes(augmented_inputs, tfms)
# merge all detected boxes to obtain final predictions for boxes
if self.enable_cvpods_tta:
merged_instances = self._merge_detections_cvpods_tta(all_boxes, all_scores, all_classes, orig_shape)
else:
merged_instances = self._merge_detections(all_boxes, all_scores, all_classes, orig_shape)
return {"instances": merged_instances}
def _merge_detections(self, all_boxes, all_scores, all_classes, shape_hw):
# select from the union of all results
num_boxes = len(all_boxes)
num_classes = self.cfg.MODEL.DiffusionDet.NUM_CLASSES
# +1 because fast_rcnn_inference expects background scores as well
all_scores_2d = torch.zeros(num_boxes, num_classes + 1, device=all_boxes.device)
for idx, cls, score in zip(count(), all_classes, all_scores):
all_scores_2d[idx, cls] = score
merged_instances, _ = fast_rcnn_inference_single_image(
all_boxes,
all_scores_2d,
shape_hw,
1e-8,
self.cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST,
self.cfg.TEST.DETECTIONS_PER_IMAGE,
)
return merged_instances
def _merge_detections_cvpods_tta(self, all_boxes, all_scores, all_classes, shape_hw):
all_scores = torch.tensor(all_scores).to(all_boxes.device)
all_classes = torch.tensor(all_classes).to(all_boxes.device)
all_boxes, all_scores, all_classes = self.merge_result_from_multi_scales(
all_boxes, all_scores, all_classes,
nms_type="soft_vote", vote_thresh=0.65,
max_detection=self.max_detection
)
all_boxes = Boxes(all_boxes)
all_boxes.clip(shape_hw)
result = Instances(shape_hw)
result.pred_boxes = all_boxes
result.scores = all_scores
result.pred_classes = all_classes.long()
return result
def merge_result_from_multi_scales(
self, boxes, scores, labels, nms_type="soft-vote", vote_thresh=0.65, max_detection=100
):
boxes, scores, labels = self.batched_vote_nms(
boxes, scores, labels, nms_type, vote_thresh
)
number_of_detections = boxes.shape[0]
# Limit to max_per_image detections **over all classes**
if number_of_detections > max_detection > 0:
boxes = boxes[:max_detection]
scores = scores[:max_detection]
labels = labels[:max_detection]
return boxes, scores, labels
def batched_vote_nms(self, boxes, scores, labels, vote_type, vote_thresh=0.65):
# apply per class level nms, add max_coordinates on boxes first, then remove it.
labels = labels.float()
max_coordinates = boxes.max() + 1
offsets = labels.reshape(-1, 1) * max_coordinates
boxes = boxes + offsets
boxes, scores, labels = self.bbox_vote(boxes, scores, labels, vote_thresh, vote_type)
boxes -= labels.reshape(-1, 1) * max_coordinates
return boxes, scores, labels
def bbox_vote(self, boxes, scores, labels, vote_thresh, vote_type="softvote"):
assert boxes.shape[0] == scores.shape[0] == labels.shape[0]
det = torch.cat((boxes, scores.reshape(-1, 1), labels.reshape(-1, 1)), dim=1)
vote_results = torch.zeros(0, 6, device=det.device)
if det.numel() == 0:
return vote_results[:, :4], vote_results[:, 4], vote_results[:, 5]
order = scores.argsort(descending=True)
det = det[order]
while det.shape[0] > 0:
# IOU
area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
xx1 = torch.max(det[0, 0], det[:, 0])
yy1 = torch.max(det[0, 1], det[:, 1])
xx2 = torch.min(det[0, 2], det[:, 2])
yy2 = torch.min(det[0, 3], det[:, 3])
w = torch.clamp(xx2 - xx1, min=0.)
h = torch.clamp(yy2 - yy1, min=0.)
inter = w * h
iou = inter / (area[0] + area[:] - inter)
# get needed merge det and delete these det
merge_index = torch.where(iou >= vote_thresh)[0]
vote_det = det[merge_index, :]
det = det[iou < vote_thresh]
if merge_index.shape[0] <= 1:
vote_results = torch.cat((vote_results, vote_det), dim=0)
else:
if vote_type == "soft_vote":
vote_det_iou = iou[merge_index]
det_accu_sum = self.get_soft_dets_sum(vote_det, vote_det_iou)
elif vote_type == "vote":
det_accu_sum = self.get_dets_sum(vote_det)
vote_results = torch.cat((vote_results, det_accu_sum), dim=0)
order = vote_results[:, 4].argsort(descending=True)
vote_results = vote_results[order, :]
return vote_results[:, :4], vote_results[:, 4], vote_results[:, 5]
@staticmethod
def get_dets_sum(vote_det):
vote_det[:, :4] *= vote_det[:, 4:5].repeat(1, 4)
max_score = vote_det[:, 4].max()
det_accu_sum = torch.zeros((1, 6), device=vote_det.device)
det_accu_sum[:, :4] = torch.sum(vote_det[:, :4], dim=0) / torch.sum(vote_det[:, 4])
det_accu_sum[:, 4] = max_score
det_accu_sum[:, 5] = vote_det[0, 5]
return det_accu_sum
@staticmethod
def get_soft_dets_sum(vote_det, vote_det_iou):
soft_vote_det = vote_det.detach().clone()
soft_vote_det[:, 4] *= (1 - vote_det_iou)
INFERENCE_TH = 0.05
soft_index = torch.where(soft_vote_det[:, 4] >= INFERENCE_TH)[0]
soft_vote_det = soft_vote_det[soft_index, :]
vote_det[:, :4] *= vote_det[:, 4:5].repeat(1, 4)
max_score = vote_det[:, 4].max()
det_accu_sum = torch.zeros((1, 6), device=vote_det.device)
det_accu_sum[:, :4] = torch.sum(vote_det[:, :4], dim=0) / torch.sum(vote_det[:, 4])
det_accu_sum[:, 4] = max_score
det_accu_sum[:, 5] = vote_det[0, 5]
if soft_vote_det.shape[0] > 0:
det_accu_sum = torch.cat((det_accu_sum, soft_vote_det), dim=0)
return det_accu_sum
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Utilities for bounding box manipulation and GIoU.
"""
import torch
from torchvision.ops.boxes import box_area
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
(x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=-1)
def box_xyxy_to_cxcywh(x):
x0, y0, x1, y1 = x.unbind(-1)
b = [(x0 + x1) / 2, (y0 + y1) / 2,
(x1 - x0), (y1 - y0)]
return torch.stack(b, dim=-1)
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
union = area1[:, None] + area2 - inter
iou = inter / union
return iou, union
def generalized_box_iou(boxes1, boxes2):
"""
Generalized IoU from https://giou.stanford.edu/
The boxes should be in [x0, y0, x1, y1] format
Returns a [N, M] pairwise matrix, where N = len(boxes1)
and M = len(boxes2)
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
iou, union = box_iou(boxes1, boxes2)
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
wh = (rb - lt).clamp(min=0) # [N,M,2]
area = wh[:, :, 0] * wh[:, :, 1]
return iou - (area - union) / area
def masks_to_boxes(masks):
"""Compute the bounding boxes around the provided masks
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
Returns a [N, 4] tensors, with the boxes in xyxy format
"""
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device)
h, w = masks.shape[-2:]
y = torch.arange(0, h, dtype=torch.float)
x = torch.arange(0, w, dtype=torch.float)
y, x = torch.meshgrid(y, x)
x_mask = (masks * x.unsqueeze(0))
x_max = x_mask.flatten(1).max(-1)[0]
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
y_mask = (masks * y.unsqueeze(0))
y_max = y_mask.flatten(1).max(-1)[0]
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
return torch.stack([x_min, y_min, x_max, y_max], 1)
import numpy as np
def colormap(rgb=False):
color_list = np.array(
[
0.000, 0.447, 0.741,
0.850, 0.325, 0.098,
0.929, 0.694, 0.125,
0.494, 0.184, 0.556,
0.466, 0.674, 0.188,
0.301, 0.745, 0.933,
0.635, 0.078, 0.184,
0.300, 0.300, 0.300,
0.600, 0.600, 0.600,
1.000, 0.000, 0.000,
1.000, 0.500, 0.000,
0.749, 0.749, 0.000,
0.000, 1.000, 0.000,
0.000, 0.000, 1.000,
0.667, 0.000, 1.000,
0.333, 0.333, 0.000,
0.333, 0.667, 0.000,
0.333, 1.000, 0.000,
0.667, 0.333, 0.000,
0.667, 0.667, 0.000,
0.667, 1.000, 0.000,
1.000, 0.333, 0.000,
1.000, 0.667, 0.000,
1.000, 1.000, 0.000,
0.000, 0.333, 0.500,
0.000, 0.667, 0.500,
0.000, 1.000, 0.500,
0.333, 0.000, 0.500,
0.333, 0.333, 0.500,
0.333, 0.667, 0.500,
0.333, 1.000, 0.500,
0.667, 0.000, 0.500,
0.667, 0.333, 0.500,
0.667, 0.667, 0.500,
0.667, 1.000, 0.500,
1.000, 0.000, 0.500,
1.000, 0.333, 0.500,
1.000, 0.667, 0.500,
1.000, 1.000, 0.500,
0.000, 0.333, 1.000,
0.000, 0.667, 1.000,
0.000, 1.000, 1.000,
0.333, 0.000, 1.000,
0.333, 0.333, 1.000,
0.333, 0.667, 1.000,
0.333, 1.000, 1.000,
0.667, 0.000, 1.000,
0.667, 0.333, 1.000,
0.667, 0.667, 1.000,
0.667, 1.000, 1.000,
1.000, 0.000, 1.000,
1.000, 0.333, 1.000,
1.000, 0.667, 1.000,
0.167, 0.000, 0.000,
0.333, 0.000, 0.000,
0.500, 0.000, 0.000,
0.667, 0.000, 0.000,
0.833, 0.000, 0.000,
1.000, 0.000, 0.000,
0.000, 0.167, 0.000,
0.000, 0.333, 0.000,
0.000, 0.500, 0.000,
0.000, 0.667, 0.000,
0.000, 0.833, 0.000,
0.000, 1.000, 0.000,
0.000, 0.000, 0.167,
0.000, 0.000, 0.333,
0.000, 0.000, 0.500,
0.000, 0.000, 0.667,
0.000, 0.000, 0.833,
0.000, 0.000, 1.000,
0.000, 0.000, 0.000,
0.143, 0.143, 0.143,
0.286, 0.286, 0.286,
0.429, 0.429, 0.429,
0.571, 0.571, 0.571,
0.714, 0.714, 0.714,
0.857, 0.857, 0.857,
1.000, 1.000, 1.000
]
).astype(np.float32)
color_list = color_list.reshape((-1, 3)) * 255
if not rgb:
color_list = color_list[:, ::-1]
return color_list
def category():
category = [
"person",
"bicycle",
"car",
"motorbike",
"aeroplane",
"bus",
"train",
"truck",
"boat",
"traffic light",
"fire hydrant",
"stop sign",
"parking meter",
"bench",
"bird",
"cat",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"backpack",
"umbrella",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"sofa",
"pottedplant",
"bed",
"diningtable",
"toilet",
"tvmonitor",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush"]
return category
\ No newline at end of file
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
import os
import subprocess
import time
from collections import defaultdict, deque
import datetime
import pickle
from typing import Optional, List
import torch
import torch.distributed as dist
from torch import Tensor
# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
TORCH_MAJOR = int(torchvision.__version__.split('.')[0])
TORCH_MINOR = int(torchvision.__version__.split('.')[1])
if TORCH_MAJOR == 0 and TORCH_MINOR < 7:
from torchvision.ops import _new_empty_tensor
from torchvision.ops.misc import _output_size
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
# obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device="cuda")
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
if local_size != max_size:
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def reduce_dict(input_dict, average=True):
"""
Args:
input_dict (dict): all the values will be reduced
average (bool): whether to do average or sum
Reduce the values in the dictionary from all processes so that all processes
have the averaged results. Returns a dict with the same fields as
input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.all_reduce(values)
if average:
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
if torch.cuda.is_available():
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}',
'max mem: {memory:.0f}'
])
else:
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
])
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
else:
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable)))
def get_sha():
cwd = os.path.dirname(os.path.abspath(__file__))
def _run(command):
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
sha = 'N/A'
diff = "clean"
branch = 'N/A'
try:
sha = _run(['git', 'rev-parse', 'HEAD'])
subprocess.check_output(['git', 'diff'], cwd=cwd)
diff = _run(['git', 'diff-index', 'HEAD'])
diff = "has uncommited changes" if diff else "clean"
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
except Exception:
pass
message = f"sha: {sha}, status: {diff}, branch: {branch}"
return message
def collate_fn(batch):
batch = list(zip(*batch))
batch[0] = nested_tensor_from_tensor_list(batch[0])
return tuple(batch)
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
class NestedTensor(object):
def __init__(self, tensors, mask: Optional[Tensor]):
self.tensors = tensors
self.mask = mask
def to(self, device):
# type: (Device) -> NestedTensor # noqa
cast_tensor = self.tensors.to(device)
mask = self.mask
if mask is not None:
assert mask is not None
cast_mask = mask.to(device)
else:
cast_mask = None
return NestedTensor(cast_tensor, cast_mask)
def decompose(self):
return self.tensors, self.mask
def __repr__(self):
return str(self.tensors)
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
# TODO make this more general
if tensor_list[0].ndim == 3:
if torchvision._is_tracing():
# nested_tensor_from_tensor_list() does not export well to ONNX
# call _onnx_nested_tensor_from_tensor_list() instead
return _onnx_nested_tensor_from_tensor_list(tensor_list)
# TODO make it support different-sized images
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
batch_shape = [len(tensor_list)] + max_size
b, c, h, w = batch_shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], :img.shape[2]] = False
else:
raise ValueError('not supported')
return NestedTensor(tensor, mask)
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
max_size = []
for i in range(tensor_list[0].dim()):
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
max_size.append(max_size_i)
max_size = tuple(max_size)
# work around for
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
# m[: img.shape[1], :img.shape[2]] = False
# which is not yet supported in onnx
padded_imgs = []
padded_masks = []
for img in tensor_list:
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
padded_imgs.append(padded_img)
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
padded_masks.append(padded_mask.to(torch.bool))
tensor = torch.stack(padded_imgs)
mask = torch.stack(padded_masks)
return NestedTensor(tensor, mask=mask)
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(
args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
if target.numel() == 0:
return [torch.zeros([], device=output.device)]
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
"""
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
if float(torchvision.__version__[:3]) < 0.7:
if input.numel() > 0:
return torch.nn.functional.interpolate(
input, size, scale_factor, mode, align_corners
)
output_shape = _output_size(2, input, size, scale_factor)
output_shape = list(input.shape[:-2]) + list(output_shape)
return _new_empty_tensor(input, output_shape)
else:
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import math
import itertools
import logging
from typing import Dict, Any
from contextlib import contextmanager
import torch
from detectron2.engine.train_loop import HookBase
from detectron2.checkpoint import DetectionCheckpointer
logger = logging.getLogger(__name__)
class EMADetectionCheckpointer(DetectionCheckpointer):
def resume_or_load(self, path: str, *, resume: bool = True) -> Dict[str, Any]:
"""
If `resume` is True, this method attempts to resume from the last
checkpoint, if exists. Otherwise, load checkpoint from the given path.
This is useful when restarting an interrupted training job.
Args:
path (str): path to the checkpoint.
resume (bool): if True, resume from the last checkpoint if it exists
and load the model together with all the checkpointables. Otherwise
only load the model without loading any checkpointables.
Returns:
same as :meth:`load`.
"""
if resume and self.has_checkpoint():
path = self.get_checkpoint_file()
return self.load(path)
else:
# workaround `self.load`
return self.load(path, checkpointables=None) # modify
class EMAState(object):
def __init__(self):
self.state = {}
@classmethod
def FromModel(cls, model: torch.nn.Module, device: str = ""):
ret = cls()
ret.save_from(model, device)
return ret
def save_from(self, model: torch.nn.Module, device: str = ""):
"""Save model state from `model` to this object"""
for name, val in self.get_model_state_iterator(model):
val = val.detach().clone()
self.state[name] = val.to(device) if device else val
def apply_to(self, model: torch.nn.Module):
"""Apply state to `model` from this object"""
with torch.no_grad():
for name, val in self.get_model_state_iterator(model):
assert (
name in self.state
), f"Name {name} not existed, available names {self.state.keys()}"
val.copy_(self.state[name])
@contextmanager
def apply_and_restore(self, model):
old_state = EMAState.FromModel(model, self.device)
self.apply_to(model)
yield old_state
old_state.apply_to(model)
def get_ema_model(self, model):
ret = copy.deepcopy(model)
self.apply_to(ret)
return ret
@property
def device(self):
if not self.has_inited():
return None
return next(iter(self.state.values())).device
def to(self, device):
for name in self.state:
self.state[name] = self.state[name].to(device)
return self
def has_inited(self):
return self.state
def clear(self):
self.state.clear()
return self
def get_model_state_iterator(self, model):
param_iter = model.named_parameters()
buffer_iter = model.named_buffers()
return itertools.chain(param_iter, buffer_iter)
def state_dict(self):
return self.state
def load_state_dict(self, state_dict, strict: bool = True):
self.clear()
for x, y in state_dict.items():
self.state[x] = y
return torch.nn.modules.module._IncompatibleKeys(
missing_keys=[], unexpected_keys=[]
)
def __repr__(self):
ret = f"EMAState(state=[{','.join(self.state.keys())}])"
return ret
class EMAUpdater(object):
"""Model Exponential Moving Average
Keep a moving average of everything in the model state_dict (parameters and
buffers). This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
Note: It's very important to set EMA for ALL network parameters (instead of
parameters that require gradient), including batch-norm moving average mean
and variance. This leads to significant improvement in accuracy.
For example, for EfficientNetB3, with default setting (no mixup, lr exponential
decay) without bn_sync, the EMA accuracy with EMA on params that requires
gradient is 79.87%, while the corresponding accuracy with EMA on all params
is 80.61%.
Also, bn sync should be switched on for EMA.
"""
def __init__(self, state: EMAState, decay: float = 0.999, device: str = "", yolox: bool = False):
self.decay = decay
self.device = device
self.state = state
self.updates = 0
self.yolox = yolox
if yolox:
decay = 0.9998
self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
def init_state(self, model):
self.state.clear()
self.state.save_from(model, self.device)
def update(self, model):
with torch.no_grad():
self.updates += 1
d = self.decay(self.updates) if self.yolox else self.decay
for name, val in self.state.get_model_state_iterator(model):
ema_val = self.state.state[name]
if self.device:
val = val.to(self.device)
ema_val.copy_(ema_val * d + val * (1.0 - d))
def add_model_ema_configs(_C):
_C.MODEL_EMA = type(_C)()
_C.MODEL_EMA.ENABLED = False
_C.MODEL_EMA.DECAY = 0.999
# use the same as MODEL.DEVICE when empty
_C.MODEL_EMA.DEVICE = ""
# When True, loading the ema weight to the model when eval_only=True in build_model()
_C.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = False
# when True, use YOLOX EMA: https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/ema.py#L22
_C.MODEL_EMA.YOLOX = False
def _remove_ddp(model):
from torch.nn.parallel import DistributedDataParallel
if isinstance(model, DistributedDataParallel):
return model.module
return model
def may_build_model_ema(cfg, model):
if not cfg.MODEL_EMA.ENABLED:
return
model = _remove_ddp(model)
assert not hasattr(
model, "ema_state"
), "Name `ema_state` is reserved for model ema."
model.ema_state = EMAState()
logger.info("Using Model EMA.")
def may_get_ema_checkpointer(cfg, model):
if not cfg.MODEL_EMA.ENABLED:
return {}
model = _remove_ddp(model)
return {"ema_state": model.ema_state}
def get_model_ema_state(model):
"""Return the ema state stored in `model`"""
model = _remove_ddp(model)
assert hasattr(model, "ema_state")
ema = model.ema_state
return ema
def apply_model_ema(model, state=None, save_current=False):
"""Apply ema stored in `model` to model and returns a function to restore
the weights are applied
"""
model = _remove_ddp(model)
if state is None:
state = get_model_ema_state(model)
if save_current:
# save current model state
old_state = EMAState.FromModel(model, state.device)
state.apply_to(model)
if save_current:
return old_state
return None
@contextmanager
def apply_model_ema_and_restore(model, state=None):
"""Apply ema stored in `model` to model and returns a function to restore
the weights are applied
"""
model = _remove_ddp(model)
if state is None:
state = get_model_ema_state(model)
old_state = EMAState.FromModel(model, state.device)
state.apply_to(model)
yield old_state
old_state.apply_to(model)
class EMAHook(HookBase):
def __init__(self, cfg, model):
model = _remove_ddp(model)
assert cfg.MODEL_EMA.ENABLED
assert hasattr(
model, "ema_state"
), "Call `may_build_model_ema` first to initilaize the model ema"
self.model = model
self.ema = self.model.ema_state
self.device = cfg.MODEL_EMA.DEVICE or cfg.MODEL.DEVICE
self.ema_updater = EMAUpdater(
self.model.ema_state, decay=cfg.MODEL_EMA.DECAY, device=self.device, yolox=cfg.MODEL_EMA.YOLOX
)
def before_train(self):
if self.ema.has_inited():
self.ema.to(self.device)
else:
self.ema_updater.init_state(self.model)
def after_train(self):
pass
def before_step(self):
pass
def after_step(self):
if not self.model.train:
return
self.ema_updater.update(self.model)
"""
Plotting utilities to visualize training logs.
"""
import torch
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path, PurePath
def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
'''
Function to plot specific fields from training log(s). Plots both training and test results.
:: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
- fields = which results to plot from each log file - plots both training and test for each field.
- ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
- log_name = optional, name of log file if different than default 'log.txt'.
:: Outputs - matplotlib plots of results in fields, color coded for each log file.
- solid lines are training results, dashed lines are test results.
'''
func_name = "plot_utils.py::plot_logs"
# verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
# convert single Path to list to avoid 'not iterable' error
if not isinstance(logs, list):
if isinstance(logs, PurePath):
logs = [logs]
print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
else:
raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
Expect list[Path] or single Path obj, received {type(logs)}")
# Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
for i, dir in enumerate(logs):
if not isinstance(dir, PurePath):
raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
if not dir.exists():
raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
# verify log_name exists
fn = Path(dir / log_name)
if not fn.exists():
print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
print(f"--> full path of missing log file: {fn}")
return
# load log file(s) and plot
dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
for j, field in enumerate(fields):
if field == 'mAP':
coco_eval = pd.DataFrame(
np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
).ewm(com=ewm_col).mean()
axs[j].plot(coco_eval, c=color)
else:
df.interpolate().ewm(com=ewm_col).mean().plot(
y=[f'train_{field}', f'test_{field}'],
ax=axs[j],
color=[color] * 2,
style=['-', '--']
)
for ax, field in zip(axs, fields):
ax.legend([Path(p).name for p in logs])
ax.set_title(field)
def plot_precision_recall(files, naming_scheme='iter'):
if naming_scheme == 'exp_id':
# name becomes exp_id
names = [f.parts[-3] for f in files]
elif naming_scheme == 'iter':
names = [f.stem for f in files]
else:
raise ValueError(f'not supported {naming_scheme}')
fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
data = torch.load(f)
# precision is n_iou, n_points, n_cat, n_area, max_det
precision = data['precision']
recall = data['params'].recThrs
scores = data['scores']
# take precision for all classes, all areas and 100 detections
precision = precision[0, :, :, 0, -1].mean(1)
scores = scores[0, :, :, 0, -1].mean(1)
prec = precision.mean()
rec = data['recall'][0, :, 0, -1].mean()
print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
f'score={scores.mean():0.3f}, ' +
f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
)
axs[0].plot(recall, precision, c=color)
axs[1].plot(recall, scores, c=color)
axs[0].set_title('Precision / Recall')
axs[0].legend(names)
axs[1].set_title('Scores / Recall')
axs[1].legend(names)
return fig, axs
icon.png

77.3 KB

# 模型唯一标识
modelCode=901
# 模型名称
modelName=diffusiondet_pytorch
# 模型描述
modelDescription= Diffusion Model for Object Detection
# 应用场景
appScenario=推理,训练,科研,制造,医疗,家居,教育
# 框架类型
frameType=pytorch
# ==========================================
# Modified by Shoufa Chen
# ===========================================
# Modified by Peize Sun, Rufeng Zhang
# Contact: {sunpeize, cxrfzhang}@foxmail.com
#
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DiffusionDet Training Script.
This script is a simplified version of the training script in detectron2/tools.
"""
import os
import itertools
import weakref
from typing import Any, Dict, List, Set
import logging
from collections import OrderedDict
import torch
from fvcore.nn.precise_bn import get_bn_modules
import detectron2.utils.comm as comm
from detectron2.utils.logger import setup_logger
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import build_detection_train_loader
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch, create_ddp_model, \
AMPTrainer, SimpleTrainer, hooks
from detectron2.evaluation import COCOEvaluator, LVISEvaluator, verify_results
from detectron2.solver.build import maybe_add_gradient_clipping
from detectron2.modeling import build_model
from diffusiondet import DiffusionDetDatasetMapper, add_diffusiondet_config, DiffusionDetWithTTA
from diffusiondet.util.model_ema import add_model_ema_configs, may_build_model_ema, may_get_ema_checkpointer, EMAHook, \
apply_model_ema_and_restore, EMADetectionCheckpointer
class Trainer(DefaultTrainer):
""" Extension of the Trainer class adapted to DiffusionDet. """
def __init__(self, cfg):
"""
Args:
cfg (CfgNode):
"""
super(DefaultTrainer, self).__init__() # call grandfather's `__init__` while avoid father's `__init()`
logger = logging.getLogger("detectron2")
if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2
setup_logger()
cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
# Assume these objects must be constructed in this order.
model = self.build_model(cfg)
optimizer = self.build_optimizer(cfg, model)
data_loader = self.build_train_loader(cfg)
model = create_ddp_model(model, broadcast_buffers=False)
self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
model, data_loader, optimizer
)
self.scheduler = self.build_lr_scheduler(cfg, optimizer)
########## EMA ############
kwargs = {
'trainer': weakref.proxy(self),
}
kwargs.update(may_get_ema_checkpointer(cfg, model))
self.checkpointer = DetectionCheckpointer(
# Assume you want to save checkpoints together with logs/statistics
model,
cfg.OUTPUT_DIR,
**kwargs,
# trainer=weakref.proxy(self),
)
self.start_iter = 0
self.max_iter = cfg.SOLVER.MAX_ITER
self.cfg = cfg
self.register_hooks(self.build_hooks())
@classmethod
def build_model(cls, cfg):
"""
Returns:
torch.nn.Module:
It now calls :func:`detectron2.modeling.build_model`.
Overwrite it if you'd like a different model.
"""
model = build_model(cfg)
logger = logging.getLogger(__name__)
logger.info("Model:\n{}".format(model))
# setup EMA
may_build_model_ema(cfg, model)
return model
@classmethod
def build_evaluator(cls, cfg, dataset_name, output_folder=None):
"""
Create evaluator(s) for a given dataset.
This uses the special metadata "evaluator_type" associated with each builtin dataset.
For your own dataset, you can simply create an evaluator manually in your
script and do not have to worry about the hacky if-else logic here.
"""
if output_folder is None:
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
if 'lvis' in dataset_name:
return LVISEvaluator(dataset_name, cfg, True, output_folder)
else:
return COCOEvaluator(dataset_name, cfg, True, output_folder)
@classmethod
def build_train_loader(cls, cfg):
mapper = DiffusionDetDatasetMapper(cfg, is_train=True)
return build_detection_train_loader(cfg, mapper=mapper)
@classmethod
def build_optimizer(cls, cfg, model):
params: List[Dict[str, Any]] = []
memo: Set[torch.nn.parameter.Parameter] = set()
for key, value in model.named_parameters(recurse=True):
if not value.requires_grad:
continue
# Avoid duplicating parameters
if value in memo:
continue
memo.add(value)
lr = cfg.SOLVER.BASE_LR
weight_decay = cfg.SOLVER.WEIGHT_DECAY
if "backbone" in key:
lr = lr * cfg.SOLVER.BACKBONE_MULTIPLIER
params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class
# detectron2 doesn't have full model gradient clipping now
clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
enable = (
cfg.SOLVER.CLIP_GRADIENTS.ENABLED
and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
and clip_norm_val > 0.0
)
class FullModelGradientClippingOptimizer(optim):
def step(self, closure=None):
all_params = itertools.chain(*[x["params"] for x in self.param_groups])
torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
super().step(closure=closure)
return FullModelGradientClippingOptimizer if enable else optim
optimizer_type = cfg.SOLVER.OPTIMIZER
if optimizer_type == "SGD":
optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
)
elif optimizer_type == "ADAMW":
optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
params, cfg.SOLVER.BASE_LR
)
else:
raise NotImplementedError(f"no optimizer type {optimizer_type}")
if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
optimizer = maybe_add_gradient_clipping(cfg, optimizer)
return optimizer
@classmethod
def ema_test(cls, cfg, model, evaluators=None):
# model with ema weights
logger = logging.getLogger("detectron2.trainer")
if cfg.MODEL_EMA.ENABLED:
logger.info("Run evaluation with EMA.")
with apply_model_ema_and_restore(model):
results = cls.test(cfg, model, evaluators=evaluators)
else:
results = cls.test(cfg, model, evaluators=evaluators)
return results
@classmethod
def test_with_TTA(cls, cfg, model):
logger = logging.getLogger("detectron2.trainer")
logger.info("Running inference with test-time augmentation ...")
model = DiffusionDetWithTTA(cfg, model)
evaluators = [
cls.build_evaluator(
cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
)
for name in cfg.DATASETS.TEST
]
if cfg.MODEL_EMA.ENABLED:
cls.ema_test(cfg, model, evaluators)
else:
res = cls.test(cfg, model, evaluators)
res = OrderedDict({k + "_TTA": v for k, v in res.items()})
return res
def build_hooks(self):
"""
Build a list of default hooks, including timing, evaluation,
checkpointing, lr scheduling, precise BN, writing events.
Returns:
list[HookBase]:
"""
cfg = self.cfg.clone()
cfg.defrost()
cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
ret = [
hooks.IterationTimer(),
EMAHook(self.cfg, self.model) if cfg.MODEL_EMA.ENABLED else None, # EMA hook
hooks.LRScheduler(),
hooks.PreciseBN(
# Run at the same freq as (but before) evaluation.
cfg.TEST.EVAL_PERIOD,
self.model,
# Build a new data loader to not affect training
self.build_train_loader(cfg),
cfg.TEST.PRECISE_BN.NUM_ITER,
)
if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
else None,
]
# Do PreciseBN before checkpointer, because it updates the model and need to
# be saved by checkpointer.
# This is not always the best: if checkpointing has a different frequency,
# some checkpoints may have more precise statistics than others.
if comm.is_main_process():
ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
def test_and_save_results():
self._last_eval_results = self.test(self.cfg, self.model)
return self._last_eval_results
# Do evaluation after checkpointer, because then if it fails,
# we can use the saved checkpoint to debug.
ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
if comm.is_main_process():
# Here the default print/log frequency of each writer is used.
# run writers in the end, so that evaluation metrics are written
ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
return ret
def setup(args):
"""
Create configs and perform basic setups.
"""
cfg = get_cfg()
add_diffusiondet_config(cfg)
add_model_ema_configs(cfg)
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
default_setup(cfg, args)
return cfg
def main(args):
cfg = setup(args)
if args.eval_only:
model = Trainer.build_model(cfg)
kwargs = may_get_ema_checkpointer(cfg, model)
if cfg.MODEL_EMA.ENABLED:
EMADetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR, **kwargs).resume_or_load(cfg.MODEL.WEIGHTS,
resume=args.resume)
else:
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR, **kwargs).resume_or_load(cfg.MODEL.WEIGHTS,
resume=args.resume)
res = Trainer.ema_test(cfg, model)
if cfg.TEST.AUG.ENABLED:
res.update(Trainer.test_with_TTA(cfg, model))
if comm.is_main_process():
verify_results(cfg, res)
return res
trainer = Trainer(cfg)
trainer.resume_or_load(resume=args.resume)
return trainer.train()
if __name__ == "__main__":
args = default_argument_parser().parse_args()
print("Command Line Args:", args)
launch(
main,
args.num_gpus,
num_machines=args.num_machines,
machine_rank=args.machine_rank,
dist_url=args.dist_url,
args=(args,),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment