Commit c480d4e4 authored by Zhicheng Yan, committed by Facebook GitHub Bot

stabilize the training of deformable DETR with box refinement

Summary:
Major changes
- As described in detail in Appendix A.4 of the Deformable DETR paper (https://arxiv.org/abs/2010.04159), gradient back-propagation is blocked at inverse_sigmoid(bounding box x/y/w/h from the last decoder layer). This can be implemented by detaching the tensor from the compute graph in PyTorch. However, we currently detach the wrong tensor, which prevents updates to the layers that predict delta x/y/w/h. Fix this bug (a minimal sketch of the intended detach placement follows after this list).
- Add more comments annotating data types and tensor shapes in the code. This should NOT affect the actual implementation.
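For illustration only, a minimal sketch of the gradient-blocking pattern, assuming the inverse_sigmoid from detr/util/misc.py (re-implemented inline here) and toy tensors standing in for the decoder outputs:

import torch

def inverse_sigmoid(x, eps=1e-5):
    # numerically stable logit, mirroring detr.util.misc.inverse_sigmoid
    x = x.clamp(min=0, max=1)
    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))

# toy stand-ins: reference boxes from the previous decoder layer and predicted deltas
reference_points = torch.rand(2, 8, 4)            # (bs, num_queries, 4), values in (0, 1)
delta = torch.zeros(2, 8, 4, requires_grad=True)  # stands in for bbox_embed[lid](hs[lid])

# detach the reference points, not the refined boxes, so gradients from the loss
# still reach the layers that predict delta x/y/w/h
new_reference_points = delta + inverse_sigmoid(reference_points).detach()
reference_points = new_reference_points.sigmoid()

reference_points.sum().backward()
assert delta.grad is not None  # the delta-prediction layers still receive gradients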

Reviewed By: zhanghang1989

Differential Revision: D29048363

fbshipit-source-id: c5b5e89793c86d530b077a7b999769881f441b69
parent 37947353
......@@ -4,29 +4,28 @@
import numpy as np
import torch
import torch.nn.functional as F
from detectron2.layers import ShapeSpec
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, detector_postprocess
from detectron2.structures import Boxes, ImageList, Instances, BitMasks
from detr.datasets.coco import convert_coco_poly_to_mask
from detr.models.backbone import Joiner
from detr.models.deformable_detr import DeformableDETR
from detr.models.deformable_transformer import DeformableTransformer
from detr.models.detr import DETR
from detr.models.matcher import HungarianMatcher
from detr.models.position_encoding import PositionEmbeddingSine
from detr.models.segmentation import DETRsegm, PostProcessSegm
from detr.models.setcriterion import SetCriterion, FocalLossSetCriterion
from detr.models.transformer import Transformer
from detr.util.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
from detr.util.misc import NestedTensor
from torch import nn
__all__ = ["Detr"]
class ResNetMaskedBackbone(nn.Module):
""" This is a thin wrapper around D2's backbone to provide padding masking"""
"""This is a thin wrapper around D2's backbone to provide padding masking"""
def __init__(self, cfg):
super().__init__()
......@@ -48,6 +47,7 @@ class ResNetMaskedBackbone(nn.Module):
def forward(self, images):
features = self.backbone(images.tensor)
# one tensor per feature level. Each tensor has shape (B, maxH, maxW)
masks = self.mask_out_padding(
[features_per_level.shape for features_per_level in features.values()],
images.image_sizes,
......@@ -63,7 +63,9 @@ class ResNetMaskedBackbone(nn.Module):
assert len(feature_shapes) == len(self.feature_strides)
for idx, shape in enumerate(feature_shapes):
N, _, H, W = shape
masks_per_feature_level = torch.ones((N, H, W), dtype=torch.bool, device=device)
masks_per_feature_level = torch.ones(
(N, H, W), dtype=torch.bool, device=device
)
for img_idx, (h, w) in enumerate(image_sizes):
masks_per_feature_level[
img_idx,
......@@ -73,16 +75,21 @@ class ResNetMaskedBackbone(nn.Module):
masks.append(masks_per_feature_level)
return masks
class FBNetMaskedBackbone(nn.Module):
""" This is a thin wrapper around D2's backbone to provide padding masking"""
"""This is a thin wrapper around D2's backbone to provide padding masking"""
def __init__(self, cfg):
super().__init__()
self.backbone = build_backbone(cfg)
self.out_features = cfg.MODEL.FBNET_V2.OUT_FEATURES
self.feature_strides = list(self.backbone._out_feature_strides.values())
self.num_channels = [self.backbone._out_feature_channels[k] for k in self.out_features]
self.strides = [self.backbone._out_feature_strides[k] for k in self.out_features]
self.num_channels = [
self.backbone._out_feature_channels[k] for k in self.out_features
]
self.strides = [
self.backbone._out_feature_strides[k] for k in self.out_features
]
def forward(self, images):
features = self.backbone(images.tensor)
......@@ -103,7 +110,9 @@ class FBNetMaskedBackbone(nn.Module):
assert len(feature_shapes) == len(self.feature_strides)
for idx, shape in enumerate(feature_shapes):
N, _, H, W = shape
masks_per_feature_level = torch.ones((N, H, W), dtype=torch.bool, device=device)
masks_per_feature_level = torch.ones(
(N, H, W), dtype=torch.bool, device=device
)
for img_idx, (h, w) in enumerate(image_sizes):
masks_per_feature_level[
img_idx,
......@@ -147,14 +156,19 @@ class Detr(nn.Module):
num_feature_levels = cfg.MODEL.DETR.NUM_FEATURE_LEVELS
N_steps = hidden_dim // 2
if "resnet" in cfg.MODEL.BACKBONE.NAME.lower():
d2_backbone = ResNetMaskedBackbone(cfg)
elif "fbnet" in cfg.MODEL.BACKBONE.NAME.lower():
d2_backbone = FBNetMaskedBackbone(cfg)
else:
raise NotImplementedError
backbone = Joiner(
d2_backbone,
PositionEmbeddingSine(
N_steps, normalize=True, centered=centered_position_encoding
),
)
backbone.num_channels = d2_backbone.num_channels
self.use_focal_loss = cfg.MODEL.DETR.USE_FOCAL_LOSS
......@@ -171,13 +185,19 @@ class Detr(nn.Module):
num_feature_levels=num_feature_levels,
dec_n_points=4,
enc_n_points=4,
two_stage=cfg.MODEL.DETR.TWO_STAGE,
two_stage_num_proposals=num_queries,
)
self.detr = DeformableDETR(
backbone,
transformer,
num_classes=self.num_classes,
num_queries=num_queries,
num_feature_levels=num_feature_levels,
aux_loss=deep_supervision,
with_box_refine=cfg.MODEL.DETR.WITH_BOX_REFINE,
two_stage=cfg.MODEL.DETR.TWO_STAGE,
)
else:
transformer = Transformer(
......@@ -192,31 +212,41 @@ class Detr(nn.Module):
)
self.detr = DETR(
backbone,
transformer,
num_classes=self.num_classes,
num_queries=num_queries,
aux_loss=deep_supervision,
use_focal_loss=self.use_focal_loss,
)
if self.mask_on:
frozen_weights = cfg.MODEL.DETR.FROZEN_WEIGHTS
if frozen_weights != "":
print("LOAD pre-trained weights")
weight = torch.load(
frozen_weights, map_location=lambda storage, loc: storage
)["model"]
new_weight = {}
for k, v in weight.items():
if "detr." in k:
new_weight[k.replace("detr.", "")] = v
else:
print(f"Skipping loading weight {k} from frozen model")
del weight
self.detr.load_state_dict(new_weight)
del new_weight
self.detr = DETRsegm(self.detr, freeze_detr=(frozen_weights != ""))
self.seg_postprocess = PostProcessSegm
self.detr.to(self.device)
# building criterion
matcher = HungarianMatcher(
cost_class=cls_weight,
cost_bbox=l1_weight,
cost_giou=giou_weight,
use_focal_loss=self.use_focal_loss,
)
weight_dict = {"loss_ce": cls_weight, "loss_bbox": l1_weight}
weight_dict["loss_giou"] = giou_weight
if deep_supervision:
......@@ -229,11 +259,18 @@ class Detr(nn.Module):
losses += ["masks"]
if self.use_focal_loss:
self.criterion = FocalLossSetCriterion(
self.num_classes,
matcher=matcher,
weight_dict=weight_dict,
losses=losses,
)
else:
self.criterion = SetCriterion(
self.num_classes,
matcher=matcher,
weight_dict=weight_dict,
eos_coef=no_object_weight,
losses=losses,
)
self.criterion.to(self.device)
......@@ -266,6 +303,9 @@ class Detr(nn.Module):
if self.training:
gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
# targets: List[Dict[str, torch.Tensor]]. Keys
# "labels": [NUM_BOX,]
# "boxes": [NUM_BOX, 4]
targets = self.prepare_targets(gt_instances)
loss_dict = self.criterion(output, targets)
weight_dict = self.criterion.weight_dict
......@@ -279,7 +319,9 @@ class Detr(nn.Module):
mask_pred = output["pred_masks"] if self.mask_on else None
results = self.inference(box_cls, box_pred, mask_pred, images.image_sizes)
processed_results = []
for results_per_image, input_per_image, image_size in zip(
results, batched_inputs, images.image_sizes
):
height = input_per_image.get("height", image_size[0])
width = input_per_image.get("width", image_size[1])
r = detector_postprocess(results_per_image, height, width)
......@@ -290,15 +332,17 @@ class Detr(nn.Module):
new_targets = []
for targets_per_image in targets:
h, w = targets_per_image.image_size
image_size_xyxy = torch.as_tensor(
[w, h, w, h], dtype=torch.float, device=self.device
)
gt_classes = targets_per_image.gt_classes  # shape (NUM_BOX,)
gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy
gt_boxes = box_xyxy_to_cxcywh(gt_boxes)  # shape (NUM_BOX, 4)
new_targets.append({"labels": gt_classes, "boxes": gt_boxes})
if self.mask_on and hasattr(targets_per_image, "gt_masks"):
gt_masks = targets_per_image.gt_masks
gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)
new_targets[-1].update({"masks": gt_masks})
return new_targets
def inference(self, box_cls, box_pred, mask_pred, image_sizes):
......@@ -321,27 +365,41 @@ class Detr(nn.Module):
if self.use_focal_loss:
prob = box_cls.sigmoid()
# TODO make top-100 as an option for non-focal-loss as well
scores, topk_indexes = torch.topk(
prob.view(box_cls.shape[0], -1), 100, dim=1
)
topk_boxes = topk_indexes // box_cls.shape[2]
labels = topk_indexes % box_cls.shape[2]
else:
scores, labels = F.softmax(box_cls, dim=-1)[:, :, :-1].max(-1)
for i, (
scores_per_image,
labels_per_image,
box_pred_per_image,
image_size,
) in enumerate(zip(scores, labels, box_pred, image_sizes)):
result = Instances(image_size)
boxes = box_cxcywh_to_xyxy(box_pred_per_image)
if self.use_focal_loss:
boxes = torch.gather(
boxes.unsqueeze(0), 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4)
).squeeze()
result.pred_boxes = Boxes(boxes)
result.pred_boxes.scale(scale_x=image_size[1], scale_y=image_size[0])
if self.mask_on:
mask = F.interpolate(
mask_pred[i].unsqueeze(0),
size=image_size,
mode="bilinear",
align_corners=False,
)
mask = mask[0].sigmoid() > 0.5
B, N, H, W = mask_pred.shape
mask = BitMasks(mask.cpu()).crop_and_resize(
result.pred_boxes.tensor.cpu(), 32
)
result.pred_masks = mask.unsqueeze(1).to(mask_pred[0].device)
result.scores = scores_per_image
......
......@@ -128,6 +128,7 @@ class Joiner(nn.Sequential):
for x in out:
pos.append(self[1](x).to(x.tensors.dtype))
# shape a list of tensors, each tensor shape (B, C, H, W)
return out, pos
......
......@@ -10,23 +10,33 @@
"""
Deformable DETR model and criterion classes.
"""
import copy
import math
import torch
import torch.nn.functional as F
from torch import nn
from ..util import box_ops
from ..util.misc import (
NestedTensor,
nested_tensor_from_tensor_list,
accuracy,
get_world_size,
interpolate,
is_dist_avail_and_initialized,
inverse_sigmoid,
)
from .backbone import build_backbone
from .deformable_transformer import build_deforamble_transformer
from .matcher import build_matcher
from .segmentation import (
DETRsegm,
PostProcessPanoptic,
PostProcessSegm,
dice_loss,
sigmoid_focal_loss,
)
from .setcriterion import FocalLossSetCriterion
......@@ -35,10 +45,20 @@ def _get_clones(module, N):
class DeformableDETR(nn.Module):
""" This is the Deformable DETR module that performs object detection """
def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels,
aux_loss=True, with_box_refine=False, two_stage=False):
""" Initializes the model.
"""This is the Deformable DETR module that performs object detection"""
def __init__(
self,
backbone,
transformer,
num_classes,
num_queries,
num_feature_levels,
aux_loss=True,
with_box_refine=False,
two_stage=False,
):
"""Initializes the model.
Parameters:
backbone: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
......@@ -57,29 +77,38 @@ class DeformableDETR(nn.Module):
self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
self.num_feature_levels = num_feature_levels
if not two_stage:
self.query_embed = nn.Embedding(num_queries, hidden_dim * 2)
if num_feature_levels > 1:
num_backbone_outs = len(backbone.strides)
input_proj_list = []
for _ in range(num_backbone_outs):
in_channels = backbone.num_channels[_]
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim),
)
)
for _ in range(num_feature_levels - num_backbone_outs):
input_proj_list.append(
nn.Sequential(
nn.Conv2d(
in_channels, hidden_dim, kernel_size=3, stride=2, padding=1
),
nn.GroupNorm(32, hidden_dim),
)
)
in_channels = hidden_dim
self.input_proj = nn.ModuleList(input_proj_list)
else:
self.input_proj = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim),
)
]
)
self.backbone = backbone
self.aux_loss = aux_loss
self.with_box_refine = with_box_refine
......@@ -95,7 +124,11 @@ class DeformableDETR(nn.Module):
nn.init.constant_(proj[0].bias, 0)
# if two-stage, the last class_embed and bbox_embed is for region proposal generation
num_pred = (
(transformer.decoder.num_layers + 1)
if two_stage
else transformer.decoder.num_layers
)
if with_box_refine:
self.class_embed = _get_clones(self.class_embed, num_pred)
self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
......@@ -104,7 +137,9 @@ class DeformableDETR(nn.Module):
self.transformer.decoder.bbox_embed = self.bbox_embed
else:
nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
self.class_embed = nn.ModuleList(
[self.class_embed for _ in range(num_pred)]
)
self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
self.transformer.decoder.bbox_embed = None
if two_stage:
......@@ -114,31 +149,37 @@ class DeformableDETR(nn.Module):
nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)
def forward(self, samples: NestedTensor):
""" The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
It returns a dict with the following elements:
- "pred_logits": the classification logits (including no-object) for all queries.
Shape= [batch_size x num_queries x (num_classes + 1)]
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
(center_x, center_y, height, width). These values are normalized in [0, 1],
relative to the size of each individual image (disregarding possible padding).
See PostProcess for information on how to retrieve the unnormalized bounding box.
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
dictionnaries containing the two above keys for each decoder layer.
"""The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
It returns a dict with the following elements:
- "pred_logits": the classification logits (including no-object) for all queries.
Shape= [batch_size x num_queries x (num_classes + 1)]
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
(center_x, center_y, height, width). These values are normalized in [0, 1],
relative to the size of each individual image (disregarding possible padding).
See PostProcess for information on how to retrieve the unnormalized bounding box.
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
dictionnaries containing the two above keys for each decoder layer.
"""
if isinstance(samples, (list, torch.Tensor)):
samples = nested_tensor_from_tensor_list(samples)
# features is a list of num_levels NestedTensor.
# pos is a list of num_levels tensors. Each one has shape (B, H_l, W_l).
features, pos = self.backbone(samples)
# srcs is a list of num_levels tensor. Each one has shape (B, C, H_l, W_l)
srcs = []
# masks is a list of num_levels tensor. Each one has shape (B, H_l, W_l)
masks = []
for l, feat in enumerate(features):
# src shape: (N, C, H_l, W_l)
# mask shape: (N, H_l, W_l)
src, mask = feat.decompose()
srcs.append(self.input_proj[l](src))
masks.append(mask)
assert mask is not None
if self.num_feature_levels > len(srcs):
N, C, H, W = samples.tensor.size()
sample_mask = torch.ones((N, H, W), dtype=torch.bool, device=src.device)
......@@ -146,6 +187,7 @@ class DeformableDETR(nn.Module):
image_size = samples.image_sizes[idx]
h, w = image_size
sample_mask[idx, :h, :w] = False
# sample_mask shape (1, N, H, W)
sample_mask = sample_mask[None].float()
_len_srcs = len(srcs)
......@@ -163,8 +205,19 @@ class DeformableDETR(nn.Module):
query_embeds = None
if not self.two_stage:
# shape (num_queries, hidden_dim*2)
query_embeds = self.query_embed.weight
# hs shape: (num_layers, batch_size, num_queries, c)
# init_reference shape: (num_queries, 2)
# inter_references shape: (num_layers, bs, num_queries, num_levels, 2)
(
hs,
init_reference,
inter_references,
enc_outputs_class,
enc_outputs_coord_unact,
) = self.transformer(srcs, masks, pos, query_embeds)
outputs_classes = []
outputs_coords = []
......@@ -173,27 +226,49 @@ class DeformableDETR(nn.Module):
reference = init_reference
else:
reference = inter_references[lvl - 1]
# reference shape: (num_queries, 2)
reference = inverse_sigmoid(reference)
# shape (batch_size, num_queries, num_classes)
outputs_class = self.class_embed[lvl](hs[lvl])
# shape (batch_size, num_queries, 4). 4-tuple (cx, cy, w, h)
assert not torch.any(
torch.isnan(hs[lvl])
), f"lvl {lvl}, NaN hs[lvl] {hs[lvl]}"
tmp = self.bbox_embed[lvl](hs[lvl])
if reference.shape[-1] == 4:
tmp += reference
else:
assert reference.shape[-1] == 2
tmp[..., :2] += reference
# shape (batch_size, num_queries, 4). 4-tuple (cx, cy, w, h)
assert not torch.any(torch.isnan(tmp)), f"NaN tmp {tmp}"
outputs_coord = tmp.sigmoid()
assert not torch.any(
torch.isnan(outputs_coord)
), f"NaN outputs_coord {outputs_coord}"
outputs_classes.append(outputs_class)
outputs_coords.append(outputs_coord)
# shape (num_levels, batch_size, num_queries, num_classes)
outputs_class = torch.stack(outputs_classes)
# shape (num_levels, batch_size, num_queries, 4)
outputs_coord = torch.stack(outputs_coords)
out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
if self.aux_loss:
out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord)
if self.two_stage:
enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
out["enc_outputs"] = {
"pred_logits": enc_outputs_class,
"pred_boxes": enc_outputs_coord,
}
return out
@torch.jit.unused
......@@ -201,53 +276,62 @@ class DeformableDETR(nn.Module):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return [
{"pred_logits": a, "pred_boxes": b}
for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
]
class PostProcess(nn.Module):
""" This module converts the model's output into the format expected by the coco api"""
"""This module converts the model's output into the format expected by the coco api"""
@torch.no_grad()
def forward(self, outputs, target_sizes):
""" Perform the computation
"""Perform the computation
Parameters:
outputs: raw outputs of the model
target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
For evaluation, this must be the original image size (before any data augmentation)
For visualization, this should be the image size after data augment, but before padding
"""
out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]
assert len(out_logits) == len(target_sizes)
assert target_sizes.shape[1] == 2
prob = out_logits.sigmoid()
topk_values, topk_indexes = torch.topk(
prob.view(out_logits.shape[0], -1), 100, dim=1
)
scores = topk_values
topk_boxes = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
# and from relative [0, 1] to absolute [0, height] coordinates
img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
boxes = boxes * scale_fct[:, None, :]
results = [
{"scores": s, "labels": l, "boxes": b}
for s, l, b in zip(scores, labels, boxes)
]
return results
class MLP(nn.Module):
""" Very simple multi-layer perceptron (also called FFN)"""
"""Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
def forward(self, x):
for i, layer in enumerate(self.layers):
......@@ -256,7 +340,7 @@ class MLP(nn.Module):
def build(args):
num_classes = 20 if args.dataset_file != "coco" else 91
if args.dataset_file == "coco_panoptic":
num_classes = 250
device = torch.device(args.device)
......@@ -277,8 +361,8 @@ def build(args):
if args.masks:
model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))
matcher = build_matcher(args)
weight_dict = {"loss_ce": args.cls_loss_coef, "loss_bbox": args.bbox_loss_coef}
weight_dict["loss_giou"] = args.giou_loss_coef
if args.masks:
weight_dict["loss_mask"] = args.mask_loss_coef
weight_dict["loss_dice"] = args.dice_loss_coef
......@@ -286,21 +370,25 @@ def build(args):
if args.aux_loss:
aux_weight_dict = {}
for i in range(args.dec_layers - 1):
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
aux_weight_dict.update({k + f"_enc": v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
losses = ["labels", "boxes", "cardinality"]
if args.masks:
losses += ["masks"]
# num_classes, matcher, weight_dict, losses, focal_alpha=0.25
criterion = FocalLossSetCriterion(
num_classes, matcher, weight_dict, losses, focal_alpha=args.focal_alpha
)
criterion.to(device)
postprocessors = {"bbox": PostProcess()}
if args.masks:
postprocessors["segm"] = PostProcessSegm()
if args.dataset_file == "coco_panoptic":
is_thing_map = {i: i <= 90 for i in range(201)}
postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85)
postprocessors["panoptic"] = PostProcessPanoptic(
is_thing_map, threshold=0.85
)
return model, criterion, postprocessors
......@@ -8,24 +8,36 @@
# ------------------------------------------------------------------------
import copy
from typing import Optional, List
# import logging
import math
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import xavier_uniform_, constant_, normal_
from ..modules import MSDeformAttn
from ..util.misc import inverse_sigmoid
class DeformableTransformer(nn.Module):
def __init__(
self,
d_model=256,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=1024,
dropout=0.1,
activation="relu",
return_intermediate_dec=False,
num_feature_levels=4,
dec_n_points=4,
enc_n_points=4,
two_stage=False,
two_stage_num_proposals=300,
):
super().__init__()
self.d_model = d_model
......@@ -33,15 +45,29 @@ class DeformableTransformer(nn.Module):
self.two_stage = two_stage
self.two_stage_num_proposals = two_stage_num_proposals
encoder_layer = DeformableTransformerEncoderLayer(
d_model,
dim_feedforward,
dropout,
activation,
num_feature_levels,
nhead,
enc_n_points,
)
self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)
decoder_layer = DeformableTransformerDecoderLayer(
d_model,
dim_feedforward,
dropout,
activation,
num_feature_levels,
nhead,
dec_n_points,
)
self.decoder = DeformableTransformerDecoder(
decoder_layer, num_decoder_layers, return_intermediate_dec
)
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
......@@ -64,66 +90,118 @@ class DeformableTransformer(nn.Module):
m._reset_parameters()
if not self.two_stage:
xavier_uniform_(self.reference_points.weight.data, gain=1.0)
constant_(self.reference_points.bias.data, 0.0)
normal_(self.level_embed)
def get_proposal_pos_embed(self, proposals):
"""
Args
proposals: shape (bs, top_k, 4). Last dimension of size 4 denotes (cx, cy, w, h)
"""
num_pos_feats = 128
temperature = 10000
scale = 2 * math.pi
# shape (num_pos_feats)
dim_t = torch.arange(
num_pos_feats, dtype=torch.float32, device=proposals.device
)
dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
# N, L, 4
proposals = proposals.sigmoid() * scale
# N, L, 4, 128
# pos shape: (bs, top_k, 4, num_pos_feats)
pos = proposals[:, :, :, None] / dim_t
# N, L, 4, 64, 2
# pos shape: (bs, top_k, 4, num_pos_feats/2, 2) -> (bs, top_k, 4 * num_pos_feats)
pos = torch.stack(
(pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4
).flatten(2)
# pos shape: (bs, top_k, 4 * num_pos_feats) = (bs, top_k, 512)
return pos
def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
"""
Args:
memory: shape (bs, K, C) where K = \sum_l H_l * w_l
memory_padding_mask: shape (bs, K)
spatial_shapes: shape (num_levels, 2)
"""
N_, S_, C_ = memory.shape
base_scale = 4.0
proposals = []
_cur = 0
for lvl, (H_, W_) in enumerate(spatial_shapes):
# shape (bs, H_l * W_l)
mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(
N_, H_, W_, 1
)
# shape (bs, )
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
# shape (bs, )
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
# grid_y, grid_x shape (H_l, W_l)
grid_y, grid_x = torch.meshgrid(
torch.linspace(
0, H_ - 1, H_, dtype=torch.float32, device=memory.device
),
torch.linspace(
0, W_ - 1, W_, dtype=torch.float32, device=memory.device
),
)
# grid shape (H_l, W_l, 2)
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
# scale shape (bs, 1, 1, 2)
scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(
N_, 1, 1, 2
)
# grid shape (bs, H_l, W_l, 2). Value could be > 1
grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
# wh shape (bs, H_l, W_l, 2)
wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
# proposal shape (bs, H_l * W_l, 4)
proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
proposals.append(proposal)
_cur += H_ * W_
# shape (bs, K, 4) where K = \sum_l H_l * W_l
output_proposals = torch.cat(proposals, 1)
# shape (bs, K, 1)
output_proposals_valid = (
(output_proposals > 0.01) & (output_proposals < 0.99)
).all(-1, keepdim=True)
# inverse sigmoid
output_proposals = torch.log(output_proposals / (1 - output_proposals))
output_proposals = output_proposals.masked_fill(
memory_padding_mask.unsqueeze(-1), float("inf")
)
output_proposals = output_proposals.masked_fill(
~output_proposals_valid, float("inf")
)
output_memory = memory
output_memory = output_memory.masked_fill(
memory_padding_mask.unsqueeze(-1), float(0)
)
output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
output_memory = self.enc_output_norm(self.enc_output(output_memory))
return output_memory, output_proposals
def get_valid_ratio(self, mask):
_, H, W = mask.shape
# shape (bs,)
valid_H = torch.sum(~mask[:, :, 0], 1)
valid_W = torch.sum(~mask[:, 0, :], 1)
valid_ratio_h = valid_H.float() / H
valid_ratio_w = valid_W.float() / W
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
# shape (bs, 2)
return valid_ratio
def forward(self, srcs, masks, pos_embeds, query_embed=None):
"""
Args:
srcs: a list of num_levels tensors. Each has shape (N, C, H_l, W_l)
masks: a list of num_levels tensors. Each has shape (N, H_l, W_l)
pos_embeds: a list of num_levels tensors. Each has shape (N, C, H_l, W_l)
query_embed: a tensor has shape (num_queries, C)
"""
assert self.two_stage or query_embed is not None
# prepare input for encoder
......@@ -135,62 +213,134 @@ class DeformableTransformer(nn.Module):
bs, c, h, w = src.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
# src shape (bs, h_l*w_l, c)
src = src.flatten(2).transpose(1, 2)
# mask shape (bs, h_l*w_l)
mask = mask.flatten(1)
# pos_embed shape (bs, h_l*w_l, c)
pos_embed = pos_embed.flatten(2).transpose(1, 2)
# lvl_pos_embed shape (bs, h_l*w_l, c)
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
lvl_pos_embed_flatten.append(lvl_pos_embed)
src_flatten.append(src)
mask_flatten.append(mask)
# src_flatten shape: (bs, K, c) where K = \sum_l H_l * w_l
src_flatten = torch.cat(src_flatten, 1)
# mask_flatten shape: (bs, K)
mask_flatten = torch.cat(mask_flatten, 1)
# lvl_pos_embed_flatten shape: (bs, K, c)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
# spatial_shapes shape: (num_levels, 2)
spatial_shapes = torch.as_tensor(
spatial_shapes, dtype=torch.long, device=src_flatten.device
)
# level_start_index shape: (num_levels)
level_start_index = torch.cat(
(spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
)
# valid_ratios shape: (bs, num_levels, 2)
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
# encoder
# memory shape (bs, K, C) where K = \sum_l H_l * w_l
memory = self.encoder(
src_flatten,
spatial_shapes,
level_start_index,
valid_ratios,
lvl_pos_embed_flatten,
mask_flatten,
)
# prepare input for decoder
bs, _, c = memory.shape
if self.two_stage:
# output_memory shape (bs, K, C). Value = 0
# output_proposals shape (bs, K, 4)
output_memory, output_proposals = self.gen_encoder_output_proposals(
memory, mask_flatten, spatial_shapes
)
# hack implementation for two-stage Deformable DETR
# shape (bs, K, num_classes)
enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](
output_memory
)
# shape (bs, K, 4)
enc_outputs_coord_unact = (
self.decoder.bbox_embed[self.decoder.num_layers](output_memory)
+ output_proposals
)
topk = self.two_stage_num_proposals
# topk_proposals: indices of top items. Shape (bs, top_k)
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
# topk_coords_unact shape (bs, top_k, 4)
topk_coords_unact = torch.gather(
enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
)
topk_coords_unact = topk_coords_unact.detach()
# reference_points shape (bs, top_k, 4). value \in (0, 1)
reference_points = topk_coords_unact.sigmoid()
init_reference_out = reference_points
# shape (bs, top_k, C=512)
pos_trans_out = self.pos_trans_norm(
self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))
)
# query_embed shape (bs, top_k, c)
# tgt shape (bs, top_k, c)
query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
else:
# query_embed (or tgt) shape: (num_queries, c)
query_embed, tgt = torch.split(query_embed, c, dim=1)
# query_embed shape: (batch_size, num_queries, c)
query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
# tgt shape: (batch_size, num_queries, c)
tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
# reference_points shape: (batch_size, num_queries, 2), value \in (0, 1)
reference_points = self.reference_points(query_embed).sigmoid()
init_reference_out = reference_points
# decoder
# hs shape: (num_layers, batch_size, num_queries, c)
# inter_references shape: (num_layers, batch_size, num_queries, num_levels, 2)
hs, inter_references = self.decoder(
tgt,
reference_points,
memory,
spatial_shapes,
level_start_index,
valid_ratios,
query_embed,
mask_flatten,
)
inter_references_out = inter_references
if self.two_stage:
return (
hs,
init_reference_out,
inter_references_out,
enc_outputs_class,
enc_outputs_coord_unact,
)
# hs shape: (num_layers, batch_size, num_queries, c)
# init_reference_out shape: (batch_size, num_queries, 2)
# inter_references_out shape: (num_layers, bs, num_queries, num_levels, 2)
return hs, init_reference_out, inter_references_out, None, None
class DeformableTransformerEncoderLayer(nn.Module):
def __init__(
self,
d_model=256,
d_ffn=1024,
dropout=0.1,
activation="relu",
n_levels=4,
n_heads=8,
n_points=4,
):
super().__init__()
# self attention
......@@ -216,9 +366,34 @@ class DeformableTransformerEncoderLayer(nn.Module):
src = self.norm2(src)
return src
def forward(
self,
src,
pos,
reference_points,
spatial_shapes,
level_start_index,
padding_mask=None,
):
"""
Args:
src: tensor, shape (bs, K, c) where K = \sum_l H_l * w_l
pos: tensor, shape (bs, K, c)
reference_points: tensor, shape (bs, K, num_levels, 2)
spatial_shapes: tensor, shape (num_levels, 2)
level_start_index: tensor, shape (num_levels,)
padding_mask: tensor, shape: (bs, K)
"""
# self attention
# shape: (bs, \sum_l H_l * w_l, c)
src2 = self.self_attn(
self.with_pos_embed(src, pos),
reference_points,
src,
spatial_shapes,
level_start_index,
padding_mask,
)
src = src + self.dropout1(src2)
src = self.norm1(src)
......@@ -238,30 +413,68 @@ class DeformableTransformerEncoder(nn.Module):
def get_reference_points(spatial_shapes, valid_ratios, device):
reference_points_list = []
for lvl, (H_, W_) in enumerate(spatial_shapes):
# ref_y shape: (H_l, W_l)
# ref_x shape: (H_l, W_l)
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
)
# ref_y
# shape (None, H_l*W_l) / (N, None) = (N, H_l*W_l)
# value could be >1
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
# ref shape (N, H_l*W_l, 2)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
# shape (N, K, 2) where K = \sum_l (H_l * W_l)
reference_points = torch.cat(reference_points_list, 1)
# reference_points
# shape (N, K, 1, 2) * (N, 1, num_levels, 2) = (N, K, num_levels, 2)
# value should be <1
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
assert not torch.any(
torch.isnan(reference_points)
), f"nan in reference_points {reference_points}"
return reference_points
def forward(
self,
src,
spatial_shapes,
level_start_index,
valid_ratios,
pos=None,
padding_mask=None,
):
output = src
reference_points = self.get_reference_points(
spatial_shapes, valid_ratios, device=src.device
)
for _, layer in enumerate(self.layers):
output = layer(
output,
pos,
reference_points,
spatial_shapes,
level_start_index,
padding_mask,
)
# shape (bs, K, c) where K = \sum_l H_l * w_l
return output
class DeformableTransformerDecoderLayer(nn.Module):
def __init__(
self,
d_model=256,
d_ffn=1024,
dropout=0.1,
activation="relu",
n_levels=4,
n_heads=8,
n_points=4,
):
super().__init__()
# cross attention
......@@ -292,23 +505,50 @@ class DeformableTransformerDecoderLayer(nn.Module):
tgt = self.norm3(tgt)
return tgt
def forward(
self,
tgt,
query_pos,
reference_points,
src,
src_spatial_shapes,
level_start_index,
src_padding_mask=None,
):
"""
Args:
tgt: tensor, shape (batch_size, num_queries, c)
query_pos: tensor, shape: (batch_size, num_queries, c)
reference_points: tensor, shape: (batch_size, num_queries, num_levels, 2). values \in (0, 1)
src: tensor, shape (batch_size, K, c) where K = \sum_l H_l * w_l
src_spatial_shapes: tensor, shape (num_levels, 2)
level_start_index: tensor, shape (num_levels,)
src_padding_mask: tensor, (batch_size, K)
"""
# self attention
q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(
q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1)
)[0].transpose(0, 1)
# tgt shape: (batch_size, num_queries, c)
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
# cross attention
tgt2 = self.cross_attn(
self.with_pos_embed(tgt, query_pos),
reference_points,
src,
src_spatial_shapes,
level_start_index,
src_padding_mask,
)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# ffn
tgt = self.forward_ffn(tgt)
# tgt shape: (batch_size, num_queries, c)
return tgt
......@@ -322,41 +562,89 @@ class DeformableTransformerDecoder(nn.Module):
self.bbox_embed = None
self.class_embed = None
def forward(
self,
tgt,
reference_points,
src,
src_spatial_shapes,
src_level_start_index,
src_valid_ratios,
query_pos=None,
src_padding_mask=None,
):
"""
Args:
tgt: tensor, shape (batch_size, num_queries, c)
reference_points: tensor, shape (batch_size, num_queries, 2 or 4).
values \in (0, 1)
src: tensor, shape (batch_size, K, c) where K = \sum_l H_l * w_l
src_spatial_shapes: tensor, shape (num_levels, 2)
src_level_start_index: tensor, shape (num_levels,)
src_valid_ratios: tensor, shape (batch_size, num_levels, 2)
query_pos: tensor, shape: (batch_size, num_queries, c)
src_padding_mask: tensor, (bs, K)
"""
output = tgt
intermediate = []
intermediate_reference_points = []
for lid, layer in enumerate(self.layers):
if reference_points.shape[-1] == 4:
# shape: (bs, num_queries, 1, 4) * (bs, 1, num_levels, 4) = (bs, num_queries, num_levels, 4)
reference_points_input = (
reference_points[:, :, None]
* torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]
)
else:
assert reference_points.shape[-1] == 2
# shape (bs, num_queries, 1, 2) * (bs, 1, num_levels, 2) = (bs, num_queries, num_levels, 2)
reference_points_input = (
reference_points[:, :, None] * src_valid_ratios[:, None]
)
# shape: (bs, num_queries, c)
output = layer(
output,
query_pos,
reference_points_input,
src,
src_spatial_shapes,
src_level_start_index,
src_padding_mask,
)
assert not torch.any(torch.isnan(output)), f"NaN, lid {lid}, {output}"
# hack implementation for iterative bounding box refinement
if self.bbox_embed is not None:
tmp = self.bbox_embed[lid](output)
reference_points_unact = inverse_sigmoid(reference_points)
# block gradient backpropagation here to avoid unstable optimization
reference_points_unact_detach = reference_points_unact.detach()
if reference_points.shape[-1] == 4:
new_reference_points = tmp + reference_points_unact_detach
assert not torch.any(
torch.isnan(new_reference_points)
), f"NaN, reference_points {reference_points}, new_reference_points {new_reference_points}"
else:
assert reference_points.shape[-1] == 2
new_reference_points = tmp
new_reference_points[..., :2] = tmp[..., :2] + reference_points_unact_detach
reference_points = new_reference_points.sigmoid()
if self.return_intermediate:
intermediate.append(output)
intermediate_reference_points.append(reference_points)
if self.return_intermediate:
# shape 1: (num_layers, batch_size, num_queries, c)
# shape 2: (num_layers, bs, num_queries, num_levels, 2)
return torch.stack(intermediate), torch.stack(intermediate_reference_points)
# output shape: (batch_size, num_queries, c)
# reference_points shape: (bs, num_queries, num_levels, 2)
return output, reference_points
......@@ -372,7 +660,7 @@ def _get_activation_fn(activation):
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
def build_deforamble_transformer(args):
......@@ -389,6 +677,5 @@ def build_deforamble_transformer(args):
dec_n_points=args.dec_n_points,
enc_n_points=args.enc_n_points,
two_stage=args.two_stage,
two_stage_num_proposals=args.num_queries,
)
......@@ -63,12 +63,18 @@ class DETR(nn.Module):
samples = nested_tensor_from_tensor_list(samples)
features, pos = self.backbone(samples)
# src shape (B, C, H, W)
# mask shape (B, H, W)
src, mask = features[-1].decompose()
assert mask is not None
# hs shape (NUM_LAYER, B, S, hidden_dim)
hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
# shape (NUM_LAYER, B, S, NUM_CLASS + 1)
outputs_class = self.class_embed(hs)
# shape (NUM_LAYER, B, S, 4)
outputs_coord = self.bbox_embed(hs).sigmoid()
# pred_logits shape (B, S, NUM_CLASS + 1)
# pred_boxes shape (B, S, 4)
out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
if self.aux_loss:
out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
......
......@@ -65,8 +65,8 @@ class HungarianMatcher(nn.Module):
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
# Also concat the target labels and boxes
tgt_ids = torch.cat([v["labels"] for v in targets])
tgt_bbox = torch.cat([v["boxes"] for v in targets])
tgt_ids = torch.cat([v["labels"] for v in targets]) # [\sum_b NUM-BOX_b,]
tgt_bbox = torch.cat([v["boxes"] for v in targets]) # [\sum_b NUM-BOX_b, 4]
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
# but approximate it in 1 - proba[target class].
......@@ -78,20 +78,23 @@ class HungarianMatcher(nn.Module):
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
else:
cost_class = -out_prob[:, tgt_ids] # shape [batch_size * num_queries, \sum_b NUM-BOX_b]
# Compute the L1 cost between boxes
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) # shape [batch_size * num_queries,\sum_b NUM-BOX_b]
# Compute the giou cost between boxes
# shape [batch_size * num_queries, \sum_b NUM-BOX_b]
cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
# Final cost matrix
C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
C = C.view(bs, num_queries, -1).cpu() # shape [batch_size, num_queries, \sum_b NUM-BOX_b]
sizes = [len(v["boxes"]) for v in targets]
sizes = [len(v["boxes"]) for v in targets] # shape [batch_size,]
# each split c shape [batch_size, num_queries, NUM-BOX_b]
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
# A list where each item is [row_indices, col_indices]
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
......
......@@ -33,7 +33,7 @@ class PositionEmbeddingSine(nn.Module):
mask = tensor_list.mask
assert mask is not None
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32) # shape (B, H, W)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
......@@ -45,13 +45,13 @@ class PositionEmbeddingSine(nn.Module):
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) # shape (N, )
pos_x = x_embed[:, :, :, None] / dim_t # shape (B, H, W, N)
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) # shape (B, H, W, N)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) # shape (B, H, W, N)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) # shape (B, 2*N, H, W)
return pos
......
......@@ -39,10 +39,18 @@ class SetCriterion(nn.Module):
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
"""
assert 'pred_logits' in outputs
# shape (batch_size, num_queries, NUM_CLASS + 1)
src_logits = outputs['pred_logits']
# idx = (batch_idx, src_idx)
# batch_idx shape [\sum_b num_match_b]
# src_idx shape [\sum_b num_match_b]
idx = self._get_src_permutation_idx(indices)
# targets: List[Dict[str, torch.Tensor]]. Keys
# "labels": [NUM_BOX,]
# "boxes": [NUM_BOX, 4]
# target_classes_o shape [batch_size * num_match]
target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
# shape (batch_size, num_queries)
target_classes = torch.full(src_logits.shape[:2], self.num_classes,
dtype=torch.int64, device=src_logits.device)
target_classes[idx] = target_classes_o
......@@ -76,7 +84,9 @@ class SetCriterion(nn.Module):
"""
assert 'pred_boxes' in outputs
idx = self._get_src_permutation_idx(indices)
# shape [\sum_b num_matches_b, 4]
src_boxes = outputs['pred_boxes'][idx]
# shape [\sum_b num_matches_b, 4]
target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
......@@ -121,14 +131,14 @@ class SetCriterion(nn.Module):
def _get_src_permutation_idx(self, indices):
# permute predictions following indices
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) # shape [\sum_b num_match_b]
src_idx = torch.cat([src for (src, _) in indices]) # shape [\sum_b num_match_b]
return batch_idx, src_idx
def _get_tgt_permutation_idx(self, indices):
# permute targets following indices
batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) # shape [\sum_b num_match_b]
tgt_idx = torch.cat([tgt for (_, tgt) in indices]) # shape [\sum_b num_match_b]
return batch_idx, tgt_idx
def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
......@@ -148,9 +158,12 @@ class SetCriterion(nn.Module):
targets: list of dicts, such that len(targets) == batch_size.
The expected keys in each dict depends on the losses applied, see each loss' doc
"""
# "pred_logits" shape (B, S, NUM_CLASS + 1)
# "pred_boxes" shape (B, S, 4)
outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
# Retrieve the matching between the outputs of the last layer and the targets
# A list where each item is [row_indices, col_indices]
indices = self.matcher(outputs_without_aux, targets)
# Compute the average number of target boxes across all nodes, for normalization purposes
......@@ -378,5 +391,3 @@ class FocalLossSetCriterion(nn.Module):
losses.update(l_dict)
return losses
......@@ -47,17 +47,25 @@ class Transformer(nn.Module):
nn.init.xavier_uniform_(p)
def forward(self, src, mask, query_embed, pos_embed):
# src shape (B, C, H, W)
# mask shape (B, H, W)
# query_embed shape (M, C)
# pos_embed shape (B, C, H, W)
# flatten NxCxHxW to HWxNxC
bs, c, h, w = src.shape
src = src.flatten(2).permute(2, 0, 1)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
mask = mask.flatten(1)
src = src.flatten(2).permute(2, 0, 1) # shape (L, B, C)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1) # shape (L, B, C)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # shape (M, B, C)
mask = mask.flatten(1) # shape (B, HxW)
tgt = torch.zeros_like(query_embed)
# memory shape (L, B, C)
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
# hs shape (NUM_LEVEL, S, B, C)
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
pos=pos_embed, query_pos=query_embed)
# return shape (NUM_LEVEL, B, S, C) and (B, C, H, W)
return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
......@@ -74,7 +82,8 @@ class TransformerEncoder(nn.Module):
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
output = src
# mask, shape (L, L)
# src_key_padding_mask, shape (B, L)
for layer in self.layers:
output = layer(output, src_mask=mask,
src_key_padding_mask=src_key_padding_mask, pos=pos)
......@@ -104,7 +113,11 @@ class TransformerDecoder(nn.Module):
output = tgt
intermediate = []
# tgt shape (L, B, C)
# tgt_mask shape (L, L)
# tgt_key_padding_mask shape (B, L)
# memory_mask shape (L, S)
# memory_key_padding_mask shape (B, S)
for layer in self.layers:
output = layer(output, memory, tgt_mask=tgt_mask,
memory_mask=memory_mask,
......@@ -122,7 +135,7 @@ class TransformerDecoder(nn.Module):
if self.return_intermediate:
return torch.stack(intermediate)
# return shape (NUM_LAYER, L, B, C)
return output.unsqueeze(0)
......@@ -153,7 +166,9 @@ class TransformerEncoderLayer(nn.Module):
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
q = k = self.with_pos_embed(src, pos) # shape (L, B, D)
# src mask, shape (L, L)
# src_key_padding_mask: shape (B, L)
src2 = self.self_attn(q, k, src, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
......@@ -218,11 +233,17 @@ class TransformerDecoderLayer(nn.Module):
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
# tgt shape (L, B, C)
# tgt_mask shape (L, L)
# tgt_key_padding_mask shape (B, L)
q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(q, k, tgt, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# memory_mask shape (L, S)
# memory_key_padding_mask shape (B, S)
# query_pos shape (L, B, C)
tgt2 = self.multihead_attn(self.with_pos_embed(tgt, query_pos),
self.with_pos_embed(memory, pos),
memory, attn_mask=memory_mask,
......@@ -232,6 +253,7 @@ class TransformerDecoderLayer(nn.Module):
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
# return tgt shape (L, B, C)
return tgt
def forward_pre(self, tgt, memory,
......
......@@ -63,8 +63,11 @@ class MSDeformAttn(nn.Module):
def _reset_parameters(self):
constant_(self.sampling_offsets.weight.data, 0.)
# shape (num_heads,)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
# shape (num_heads, 2)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
# shape (num_heads, num_levels, num_points, 2)
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
......
......@@ -50,8 +50,8 @@ def generalized_box_iou(boxes1, boxes2):
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all(), f"incorrect boxes, boxes1 {boxes1}"
assert (boxes2[:, 2:] >= boxes2[:, :2]).all(), f"incorrect boxes, boxes2 {boxes2}"
iou, union = box_iou(boxes1, boxes2)
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
......