Commit c480d4e4 authored by Zhicheng Yan, committed by Facebook GitHub Bot

stabilize the training of deformable DETR with box refinement

Summary:
Major changes
- As described in detail in Appendix A.4 of the Deformable DETR paper (https://arxiv.org/abs/2010.04159), gradient back-propagation should be blocked at inverse_sigmoid(bounding box x/y/w/h from the last decoder layer). In PyTorch this is implemented by detaching the tensor from the compute graph. However, we currently detach the wrong tensor, which prevents updating the layers that predict delta x/y/w/h. Fix this bug (a sketch of the intended pattern follows below).
- Add more comments annotating data types and tensor shapes in the code. This should NOT affect the actual implementation.
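
The intended pattern, in isolation: each decoder layer's bbox head must still receive gradients from its own delta prediction, while the refined boxes that are reused as the next layer's reference points are detached. A minimal sketch is below; `refine_boxes` and its arguments are illustrative names invented for this note, not the code changed in this diff, and only `inverse_sigmoid` mirrors the existing helper in detr.util.misc.

```python
import torch


def inverse_sigmoid(x, eps=1e-5):
    # Same formula as detr.util.misc.inverse_sigmoid: a clamped logit.
    x = x.clamp(min=0, max=1)
    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))


def refine_boxes(decoder_outputs, init_reference_points, bbox_embeds):
    """Per-layer box refinement with the detach placement from Appendix A.4.

    decoder_outputs:       list of (B, num_queries, C) tensors, one per decoder layer
    init_reference_points: (B, num_queries, 2) normalized reference points in [0, 1]
    bbox_embeds:           list of per-layer modules predicting delta (cx, cy, w, h)
    """
    reference_points = init_reference_points
    all_coords = []
    for hs, bbox_embed in zip(decoder_outputs, bbox_embeds):
        delta = bbox_embed(hs)  # (B, num_queries, 4); must keep receiving gradients
        ref = inverse_sigmoid(reference_points)
        if ref.shape[-1] == 4:
            coords = (delta + ref).sigmoid()
        else:  # 2-d reference points only shift the box center
            assert ref.shape[-1] == 2
            xy = delta[..., :2] + ref
            coords = torch.cat([xy, delta[..., 2:]], dim=-1).sigmoid()
        all_coords.append(coords)
        # Detach the refined boxes that become the next layer's reference points,
        # not `delta` itself; detaching the prediction is what blocked updates to
        # the layers predicting delta x/y/w/h.
        reference_points = coords.detach()
    return torch.stack(all_coords)  # (num_layers, B, num_queries, 4)
```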

Reviewed By: zhanghang1989

Differential Revision: D29048363

fbshipit-source-id: c5b5e89793c86d530b077a7b999769881f441b69
parent 37947353
...@@ -4,29 +4,28 @@ ...@@ -4,29 +4,28 @@
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn
from detectron2.layers import ShapeSpec from detectron2.layers import ShapeSpec
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, detector_postprocess from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, detector_postprocess
from detectron2.structures import Boxes, ImageList, Instances, BitMasks from detectron2.structures import Boxes, ImageList, Instances, BitMasks
from detr.datasets.coco import convert_coco_poly_to_mask
from detr.models.backbone import Joiner from detr.models.backbone import Joiner
from detr.models.detr import DETR
from detr.models.deformable_detr import DeformableDETR from detr.models.deformable_detr import DeformableDETR
from detr.models.setcriterion import SetCriterion, FocalLossSetCriterion from detr.models.deformable_transformer import DeformableTransformer
from detr.models.detr import DETR
from detr.models.matcher import HungarianMatcher from detr.models.matcher import HungarianMatcher
from detr.models.position_encoding import PositionEmbeddingSine from detr.models.position_encoding import PositionEmbeddingSine
from detr.models.transformer import Transformer
from detr.models.deformable_transformer import DeformableTransformer
from detr.models.segmentation import DETRsegm, PostProcessSegm from detr.models.segmentation import DETRsegm, PostProcessSegm
from detr.models.setcriterion import SetCriterion, FocalLossSetCriterion
from detr.models.transformer import Transformer
from detr.util.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh from detr.util.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
from detr.util.misc import NestedTensor from detr.util.misc import NestedTensor
from detr.datasets.coco import convert_coco_poly_to_mask from torch import nn
__all__ = ["Detr"] __all__ = ["Detr"]
class ResNetMaskedBackbone(nn.Module): class ResNetMaskedBackbone(nn.Module):
""" This is a thin wrapper around D2's backbone to provide padding masking""" """This is a thin wrapper around D2's backbone to provide padding masking"""
def __init__(self, cfg): def __init__(self, cfg):
super().__init__() super().__init__()
...@@ -48,6 +47,7 @@ class ResNetMaskedBackbone(nn.Module): ...@@ -48,6 +47,7 @@ class ResNetMaskedBackbone(nn.Module):
def forward(self, images): def forward(self, images):
features = self.backbone(images.tensor) features = self.backbone(images.tensor)
# one tensor per feature level. Each tensor has shape (B, maxH, maxW)
masks = self.mask_out_padding( masks = self.mask_out_padding(
[features_per_level.shape for features_per_level in features.values()], [features_per_level.shape for features_per_level in features.values()],
images.image_sizes, images.image_sizes,
...@@ -63,7 +63,9 @@ class ResNetMaskedBackbone(nn.Module): ...@@ -63,7 +63,9 @@ class ResNetMaskedBackbone(nn.Module):
assert len(feature_shapes) == len(self.feature_strides) assert len(feature_shapes) == len(self.feature_strides)
for idx, shape in enumerate(feature_shapes): for idx, shape in enumerate(feature_shapes):
N, _, H, W = shape N, _, H, W = shape
masks_per_feature_level = torch.ones((N, H, W), dtype=torch.bool, device=device) masks_per_feature_level = torch.ones(
(N, H, W), dtype=torch.bool, device=device
)
for img_idx, (h, w) in enumerate(image_sizes): for img_idx, (h, w) in enumerate(image_sizes):
masks_per_feature_level[ masks_per_feature_level[
img_idx, img_idx,
...@@ -73,16 +75,21 @@ class ResNetMaskedBackbone(nn.Module): ...@@ -73,16 +75,21 @@ class ResNetMaskedBackbone(nn.Module):
masks.append(masks_per_feature_level) masks.append(masks_per_feature_level)
return masks return masks
class FBNetMaskedBackbone(nn.Module): class FBNetMaskedBackbone(nn.Module):
""" This is a thin wrapper around D2's backbone to provide padding masking""" """This is a thin wrapper around D2's backbone to provide padding masking"""
def __init__(self, cfg): def __init__(self, cfg):
super().__init__() super().__init__()
self.backbone = build_backbone(cfg) self.backbone = build_backbone(cfg)
self.out_features = cfg.MODEL.FBNET_V2.OUT_FEATURES self.out_features = cfg.MODEL.FBNET_V2.OUT_FEATURES
self.feature_strides = list(self.backbone._out_feature_strides.values()) self.feature_strides = list(self.backbone._out_feature_strides.values())
self.num_channels = [self.backbone._out_feature_channels[k] for k in self.out_features] self.num_channels = [
self.strides = [self.backbone._out_feature_strides[k] for k in self.out_features] self.backbone._out_feature_channels[k] for k in self.out_features
]
self.strides = [
self.backbone._out_feature_strides[k] for k in self.out_features
]
def forward(self, images): def forward(self, images):
features = self.backbone(images.tensor) features = self.backbone(images.tensor)
...@@ -103,7 +110,9 @@ class FBNetMaskedBackbone(nn.Module): ...@@ -103,7 +110,9 @@ class FBNetMaskedBackbone(nn.Module):
assert len(feature_shapes) == len(self.feature_strides) assert len(feature_shapes) == len(self.feature_strides)
for idx, shape in enumerate(feature_shapes): for idx, shape in enumerate(feature_shapes):
N, _, H, W = shape N, _, H, W = shape
masks_per_feature_level = torch.ones((N, H, W), dtype=torch.bool, device=device) masks_per_feature_level = torch.ones(
(N, H, W), dtype=torch.bool, device=device
)
for img_idx, (h, w) in enumerate(image_sizes): for img_idx, (h, w) in enumerate(image_sizes):
masks_per_feature_level[ masks_per_feature_level[
img_idx, img_idx,
...@@ -147,14 +156,19 @@ class Detr(nn.Module): ...@@ -147,14 +156,19 @@ class Detr(nn.Module):
num_feature_levels = cfg.MODEL.DETR.NUM_FEATURE_LEVELS num_feature_levels = cfg.MODEL.DETR.NUM_FEATURE_LEVELS
N_steps = hidden_dim // 2 N_steps = hidden_dim // 2
if 'resnet' in cfg.MODEL.BACKBONE.NAME.lower(): if "resnet" in cfg.MODEL.BACKBONE.NAME.lower():
d2_backbone = ResNetMaskedBackbone(cfg) d2_backbone = ResNetMaskedBackbone(cfg)
elif 'fbnet' in cfg.MODEL.BACKBONE.NAME.lower(): elif "fbnet" in cfg.MODEL.BACKBONE.NAME.lower():
d2_backbone =FBNetMaskedBackbone(cfg) d2_backbone = FBNetMaskedBackbone(cfg)
else: else:
raise NotImplementedError raise NotImplementedError
backbone = Joiner(d2_backbone, PositionEmbeddingSine(N_steps, normalize=True, centered=centered_position_encoding)) backbone = Joiner(
d2_backbone,
PositionEmbeddingSine(
N_steps, normalize=True, centered=centered_position_encoding
),
)
backbone.num_channels = d2_backbone.num_channels backbone.num_channels = d2_backbone.num_channels
self.use_focal_loss = cfg.MODEL.DETR.USE_FOCAL_LOSS self.use_focal_loss = cfg.MODEL.DETR.USE_FOCAL_LOSS
...@@ -171,13 +185,19 @@ class Detr(nn.Module): ...@@ -171,13 +185,19 @@ class Detr(nn.Module):
num_feature_levels=num_feature_levels, num_feature_levels=num_feature_levels,
dec_n_points=4, dec_n_points=4,
enc_n_points=4, enc_n_points=4,
two_stage=False, two_stage=cfg.MODEL.DETR.TWO_STAGE,
two_stage_num_proposals=num_queries, two_stage_num_proposals=num_queries,
) )
self.detr = DeformableDETR( self.detr = DeformableDETR(
backbone, transformer, num_classes=self.num_classes, num_queries=num_queries, backbone,
num_feature_levels=num_feature_levels, aux_loss=deep_supervision, transformer,
num_classes=self.num_classes,
num_queries=num_queries,
num_feature_levels=num_feature_levels,
aux_loss=deep_supervision,
with_box_refine=cfg.MODEL.DETR.WITH_BOX_REFINE,
two_stage=cfg.MODEL.DETR.TWO_STAGE,
) )
else: else:
transformer = Transformer( transformer = Transformer(
...@@ -192,31 +212,41 @@ class Detr(nn.Module): ...@@ -192,31 +212,41 @@ class Detr(nn.Module):
) )
self.detr = DETR( self.detr = DETR(
backbone, transformer, num_classes=self.num_classes, num_queries=num_queries, backbone,
aux_loss=deep_supervision, use_focal_loss=self.use_focal_loss, transformer,
num_classes=self.num_classes,
num_queries=num_queries,
aux_loss=deep_supervision,
use_focal_loss=self.use_focal_loss,
) )
if self.mask_on: if self.mask_on:
frozen_weights = cfg.MODEL.DETR.FROZEN_WEIGHTS frozen_weights = cfg.MODEL.DETR.FROZEN_WEIGHTS
if frozen_weights != '': if frozen_weights != "":
print("LOAD pre-trained weights") print("LOAD pre-trained weights")
weight = torch.load(frozen_weights, map_location=lambda storage, loc: storage)['model'] weight = torch.load(
frozen_weights, map_location=lambda storage, loc: storage
)["model"]
new_weight = {} new_weight = {}
for k, v in weight.items(): for k, v in weight.items():
if 'detr.' in k: if "detr." in k:
new_weight[k.replace('detr.', '')] = v new_weight[k.replace("detr.", "")] = v
else: else:
print(f"Skipping loading weight {k} from frozen model") print(f"Skipping loading weight {k} from frozen model")
del weight del weight
self.detr.load_state_dict(new_weight) self.detr.load_state_dict(new_weight)
del new_weight del new_weight
self.detr = DETRsegm(self.detr, freeze_detr=(frozen_weights != '')) self.detr = DETRsegm(self.detr, freeze_detr=(frozen_weights != ""))
self.seg_postprocess = PostProcessSegm self.seg_postprocess = PostProcessSegm
self.detr.to(self.device) self.detr.to(self.device)
# building criterion # building criterion
matcher = HungarianMatcher(cost_class=cls_weight, cost_bbox=l1_weight, matcher = HungarianMatcher(
cost_giou=giou_weight, use_focal_loss=self.use_focal_loss) cost_class=cls_weight,
cost_bbox=l1_weight,
cost_giou=giou_weight,
use_focal_loss=self.use_focal_loss,
)
weight_dict = {"loss_ce": cls_weight, "loss_bbox": l1_weight} weight_dict = {"loss_ce": cls_weight, "loss_bbox": l1_weight}
weight_dict["loss_giou"] = giou_weight weight_dict["loss_giou"] = giou_weight
if deep_supervision: if deep_supervision:
...@@ -229,11 +259,18 @@ class Detr(nn.Module): ...@@ -229,11 +259,18 @@ class Detr(nn.Module):
losses += ["masks"] losses += ["masks"]
if self.use_focal_loss: if self.use_focal_loss:
self.criterion = FocalLossSetCriterion( self.criterion = FocalLossSetCriterion(
self.num_classes, matcher=matcher, weight_dict=weight_dict, losses=losses, self.num_classes,
matcher=matcher,
weight_dict=weight_dict,
losses=losses,
) )
else: else:
self.criterion = SetCriterion( self.criterion = SetCriterion(
self.num_classes, matcher=matcher, weight_dict=weight_dict, eos_coef=no_object_weight, losses=losses, self.num_classes,
matcher=matcher,
weight_dict=weight_dict,
eos_coef=no_object_weight,
losses=losses,
) )
self.criterion.to(self.device) self.criterion.to(self.device)
...@@ -266,6 +303,9 @@ class Detr(nn.Module): ...@@ -266,6 +303,9 @@ class Detr(nn.Module):
if self.training: if self.training:
gt_instances = [x["instances"].to(self.device) for x in batched_inputs] gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
# targets: List[Dict[str, torch.Tensor]]. Keys
# "labels": [NUM_BOX,]
# "boxes": [NUM_BOX, 4]
targets = self.prepare_targets(gt_instances) targets = self.prepare_targets(gt_instances)
loss_dict = self.criterion(output, targets) loss_dict = self.criterion(output, targets)
weight_dict = self.criterion.weight_dict weight_dict = self.criterion.weight_dict
...@@ -279,7 +319,9 @@ class Detr(nn.Module): ...@@ -279,7 +319,9 @@ class Detr(nn.Module):
mask_pred = output["pred_masks"] if self.mask_on else None mask_pred = output["pred_masks"] if self.mask_on else None
results = self.inference(box_cls, box_pred, mask_pred, images.image_sizes) results = self.inference(box_cls, box_pred, mask_pred, images.image_sizes)
processed_results = [] processed_results = []
for results_per_image, input_per_image, image_size in zip(results, batched_inputs, images.image_sizes): for results_per_image, input_per_image, image_size in zip(
results, batched_inputs, images.image_sizes
):
height = input_per_image.get("height", image_size[0]) height = input_per_image.get("height", image_size[0])
width = input_per_image.get("width", image_size[1]) width = input_per_image.get("width", image_size[1])
r = detector_postprocess(results_per_image, height, width) r = detector_postprocess(results_per_image, height, width)
...@@ -290,15 +332,17 @@ class Detr(nn.Module): ...@@ -290,15 +332,17 @@ class Detr(nn.Module):
new_targets = [] new_targets = []
for targets_per_image in targets: for targets_per_image in targets:
h, w = targets_per_image.image_size h, w = targets_per_image.image_size
image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device) image_size_xyxy = torch.as_tensor(
gt_classes = targets_per_image.gt_classes [w, h, w, h], dtype=torch.float, device=self.device
)
gt_classes = targets_per_image.gt_classes # shape (NUM_BOX,)
gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy
gt_boxes = box_xyxy_to_cxcywh(gt_boxes) gt_boxes = box_xyxy_to_cxcywh(gt_boxes) # shape (NUM_BOX, 4)
new_targets.append({"labels": gt_classes, "boxes": gt_boxes}) new_targets.append({"labels": gt_classes, "boxes": gt_boxes})
if self.mask_on and hasattr(targets_per_image, 'gt_masks'): if self.mask_on and hasattr(targets_per_image, "gt_masks"):
gt_masks = targets_per_image.gt_masks gt_masks = targets_per_image.gt_masks
gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w) gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)
new_targets[-1].update({'masks': gt_masks}) new_targets[-1].update({"masks": gt_masks})
return new_targets return new_targets
def inference(self, box_cls, box_pred, mask_pred, image_sizes): def inference(self, box_cls, box_pred, mask_pred, image_sizes):
...@@ -321,27 +365,41 @@ class Detr(nn.Module): ...@@ -321,27 +365,41 @@ class Detr(nn.Module):
if self.use_focal_loss: if self.use_focal_loss:
prob = box_cls.sigmoid() prob = box_cls.sigmoid()
# TODO make top-100 as an option for non-focal-loss as well # TODO make top-100 as an option for non-focal-loss as well
scores, topk_indexes = torch.topk(prob.view(box_cls.shape[0], -1), 100, dim=1) scores, topk_indexes = torch.topk(
prob.view(box_cls.shape[0], -1), 100, dim=1
)
topk_boxes = topk_indexes // box_cls.shape[2] topk_boxes = topk_indexes // box_cls.shape[2]
labels = topk_indexes % box_cls.shape[2] labels = topk_indexes % box_cls.shape[2]
else: else:
scores, labels = F.softmax(box_cls, dim=-1)[:, :, :-1].max(-1) scores, labels = F.softmax(box_cls, dim=-1)[:, :, :-1].max(-1)
for i, (scores_per_image, labels_per_image, box_pred_per_image, image_size) in enumerate(zip( for i, (
scores, labels, box_pred, image_sizes scores_per_image,
)): labels_per_image,
box_pred_per_image,
image_size,
) in enumerate(zip(scores, labels, box_pred, image_sizes)):
result = Instances(image_size) result = Instances(image_size)
boxes = box_cxcywh_to_xyxy(box_pred_per_image) boxes = box_cxcywh_to_xyxy(box_pred_per_image)
if self.use_focal_loss: if self.use_focal_loss:
boxes = torch.gather(boxes.unsqueeze(0), 1, topk_boxes.unsqueeze(-1).repeat(1,1,4)).squeeze() boxes = torch.gather(
boxes.unsqueeze(0), 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4)
).squeeze()
result.pred_boxes = Boxes(boxes) result.pred_boxes = Boxes(boxes)
result.pred_boxes.scale(scale_x=image_size[1], scale_y=image_size[0]) result.pred_boxes.scale(scale_x=image_size[1], scale_y=image_size[0])
if self.mask_on: if self.mask_on:
mask = F.interpolate(mask_pred[i].unsqueeze(0), size=image_size, mode='bilinear', align_corners=False) mask = F.interpolate(
mask_pred[i].unsqueeze(0),
size=image_size,
mode="bilinear",
align_corners=False,
)
mask = mask[0].sigmoid() > 0.5 mask = mask[0].sigmoid() > 0.5
B, N, H, W = mask_pred.shape B, N, H, W = mask_pred.shape
mask = BitMasks(mask.cpu()).crop_and_resize(result.pred_boxes.tensor.cpu(), 32) mask = BitMasks(mask.cpu()).crop_and_resize(
result.pred_boxes.tensor.cpu(), 32
)
result.pred_masks = mask.unsqueeze(1).to(mask_pred[0].device) result.pred_masks = mask.unsqueeze(1).to(mask_pred[0].device)
result.scores = scores_per_image result.scores = scores_per_image
......
...@@ -128,6 +128,7 @@ class Joiner(nn.Sequential): ...@@ -128,6 +128,7 @@ class Joiner(nn.Sequential):
for x in out: for x in out:
pos.append(self[1](x).to(x.tensors.dtype)) pos.append(self[1](x).to(x.tensors.dtype))
# shape a list of tensors, each tensor shape (B, C, H, W)
return out, pos return out, pos
......
...@@ -10,23 +10,33 @@ ...@@ -10,23 +10,33 @@
""" """
Deformable DETR model and criterion classes. Deformable DETR model and criterion classes.
""" """
import copy
import math
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
import math
from ..util import box_ops from ..util import box_ops
from ..util.misc import (NestedTensor, nested_tensor_from_tensor_list, from ..util.misc import (
accuracy, get_world_size, interpolate, NestedTensor,
is_dist_avail_and_initialized, inverse_sigmoid) nested_tensor_from_tensor_list,
accuracy,
get_world_size,
interpolate,
is_dist_avail_and_initialized,
inverse_sigmoid,
)
from .backbone import build_backbone from .backbone import build_backbone
from .matcher import build_matcher
from .segmentation import (DETRsegm, PostProcessPanoptic, PostProcessSegm,
dice_loss, sigmoid_focal_loss)
from .deformable_transformer import build_deforamble_transformer from .deformable_transformer import build_deforamble_transformer
import copy from .matcher import build_matcher
from .segmentation import (
DETRsegm,
PostProcessPanoptic,
PostProcessSegm,
dice_loss,
sigmoid_focal_loss,
)
from .setcriterion import FocalLossSetCriterion from .setcriterion import FocalLossSetCriterion
...@@ -35,10 +45,20 @@ def _get_clones(module, N): ...@@ -35,10 +45,20 @@ def _get_clones(module, N):
class DeformableDETR(nn.Module): class DeformableDETR(nn.Module):
""" This is the Deformable DETR module that performs object detection """ """This is the Deformable DETR module that performs object detection"""
def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels,
aux_loss=True, with_box_refine=False, two_stage=False): def __init__(
""" Initializes the model. self,
backbone,
transformer,
num_classes,
num_queries,
num_feature_levels,
aux_loss=True,
with_box_refine=False,
two_stage=False,
):
"""Initializes the model.
Parameters: Parameters:
backbone: torch module of the backbone to be used. See backbone.py backbone: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py transformer: torch module of the transformer architecture. See transformer.py
...@@ -57,29 +77,38 @@ class DeformableDETR(nn.Module): ...@@ -57,29 +77,38 @@ class DeformableDETR(nn.Module):
self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
self.num_feature_levels = num_feature_levels self.num_feature_levels = num_feature_levels
if not two_stage: if not two_stage:
self.query_embed = nn.Embedding(num_queries, hidden_dim*2) self.query_embed = nn.Embedding(num_queries, hidden_dim * 2)
if num_feature_levels > 1: if num_feature_levels > 1:
num_backbone_outs = len(backbone.strides) num_backbone_outs = len(backbone.strides)
input_proj_list = [] input_proj_list = []
for _ in range(num_backbone_outs): for _ in range(num_backbone_outs):
in_channels = backbone.num_channels[_] in_channels = backbone.num_channels[_]
input_proj_list.append(nn.Sequential( input_proj_list.append(
nn.Conv2d(in_channels, hidden_dim, kernel_size=1), nn.Sequential(
nn.GroupNorm(32, hidden_dim), nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
)) nn.GroupNorm(32, hidden_dim),
)
)
for _ in range(num_feature_levels - num_backbone_outs): for _ in range(num_feature_levels - num_backbone_outs):
input_proj_list.append(nn.Sequential( input_proj_list.append(
nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1), nn.Sequential(
nn.GroupNorm(32, hidden_dim), nn.Conv2d(
)) in_channels, hidden_dim, kernel_size=3, stride=2, padding=1
),
nn.GroupNorm(32, hidden_dim),
)
)
in_channels = hidden_dim in_channels = hidden_dim
self.input_proj = nn.ModuleList(input_proj_list) self.input_proj = nn.ModuleList(input_proj_list)
else: else:
self.input_proj = nn.ModuleList([ self.input_proj = nn.ModuleList(
nn.Sequential( [
nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1), nn.Sequential(
nn.GroupNorm(32, hidden_dim), nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1),
)]) nn.GroupNorm(32, hidden_dim),
)
]
)
self.backbone = backbone self.backbone = backbone
self.aux_loss = aux_loss self.aux_loss = aux_loss
self.with_box_refine = with_box_refine self.with_box_refine = with_box_refine
...@@ -95,7 +124,11 @@ class DeformableDETR(nn.Module): ...@@ -95,7 +124,11 @@ class DeformableDETR(nn.Module):
nn.init.constant_(proj[0].bias, 0) nn.init.constant_(proj[0].bias, 0)
# if two-stage, the last class_embed and bbox_embed is for region proposal generation # if two-stage, the last class_embed and bbox_embed is for region proposal generation
num_pred = (transformer.decoder.num_layers + 1) if two_stage else transformer.decoder.num_layers num_pred = (
(transformer.decoder.num_layers + 1)
if two_stage
else transformer.decoder.num_layers
)
if with_box_refine: if with_box_refine:
self.class_embed = _get_clones(self.class_embed, num_pred) self.class_embed = _get_clones(self.class_embed, num_pred)
self.bbox_embed = _get_clones(self.bbox_embed, num_pred) self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
...@@ -104,7 +137,9 @@ class DeformableDETR(nn.Module): ...@@ -104,7 +137,9 @@ class DeformableDETR(nn.Module):
self.transformer.decoder.bbox_embed = self.bbox_embed self.transformer.decoder.bbox_embed = self.bbox_embed
else: else:
nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0) nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)]) self.class_embed = nn.ModuleList(
[self.class_embed for _ in range(num_pred)]
)
self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)]) self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
self.transformer.decoder.bbox_embed = None self.transformer.decoder.bbox_embed = None
if two_stage: if two_stage:
...@@ -114,31 +149,37 @@ class DeformableDETR(nn.Module): ...@@ -114,31 +149,37 @@ class DeformableDETR(nn.Module):
nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0) nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)
def forward(self, samples: NestedTensor): def forward(self, samples: NestedTensor):
""" The forward expects a NestedTensor, which consists of: """The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W] - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
It returns a dict with the following elements: It returns a dict with the following elements:
- "pred_logits": the classification logits (including no-object) for all queries. - "pred_logits": the classification logits (including no-object) for all queries.
Shape= [batch_size x num_queries x (num_classes + 1)] Shape= [batch_size x num_queries x (num_classes + 1)]
- "pred_boxes": The normalized boxes coordinates for all queries, represented as - "pred_boxes": The normalized boxes coordinates for all queries, represented as
(center_x, center_y, height, width). These values are normalized in [0, 1], (center_x, center_y, height, width). These values are normalized in [0, 1],
relative to the size of each individual image (disregarding possible padding). relative to the size of each individual image (disregarding possible padding).
See PostProcess for information on how to retrieve the unnormalized bounding box. See PostProcess for information on how to retrieve the unnormalized bounding box.
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
dictionnaries containing the two above keys for each decoder layer. dictionnaries containing the two above keys for each decoder layer.
""" """
if isinstance(samples, (list, torch.Tensor)): if isinstance(samples, (list, torch.Tensor)):
samples = nested_tensor_from_tensor_list(samples) samples = nested_tensor_from_tensor_list(samples)
# features is a list of num_levels NestedTensor.
# pos is a list of num_levels tensors. Each one has shape (B, H_l, W_l).
features, pos = self.backbone(samples) features, pos = self.backbone(samples)
# srcs is a list of num_levels tensor. Each one has shape (B, C, H_l, W_l)
srcs = [] srcs = []
# masks is a list of num_levels tensor. Each one has shape (B, H_l, W_l)
masks = [] masks = []
for l, feat in enumerate(features): for l, feat in enumerate(features):
# src shape: (N, C, H_l, W_l)
# mask shape: (N, H_l, W_l)
src, mask = feat.decompose() src, mask = feat.decompose()
srcs.append(self.input_proj[l](src)) srcs.append(self.input_proj[l](src))
masks.append(mask) masks.append(mask)
assert mask is not None assert mask is not None
if self.num_feature_levels > len(srcs): if self.num_feature_levels > len(srcs):
N, C, H, W = samples.tensor.size() N, C, H, W = samples.tensor.size()
sample_mask = torch.ones((N, H, W), dtype=torch.bool, device=src.device) sample_mask = torch.ones((N, H, W), dtype=torch.bool, device=src.device)
...@@ -146,6 +187,7 @@ class DeformableDETR(nn.Module): ...@@ -146,6 +187,7 @@ class DeformableDETR(nn.Module):
image_size = samples.image_sizes[idx] image_size = samples.image_sizes[idx]
h, w = image_size h, w = image_size
sample_mask[idx, :h, :w] = False sample_mask[idx, :h, :w] = False
# sample_mask shape (1, N, H, W)
sample_mask = sample_mask[None].float() sample_mask = sample_mask[None].float()
_len_srcs = len(srcs) _len_srcs = len(srcs)
...@@ -163,8 +205,19 @@ class DeformableDETR(nn.Module): ...@@ -163,8 +205,19 @@ class DeformableDETR(nn.Module):
query_embeds = None query_embeds = None
if not self.two_stage: if not self.two_stage:
# shape (num_queries, hidden_dim*2)
query_embeds = self.query_embed.weight query_embeds = self.query_embed.weight
hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact = self.transformer(srcs, masks, pos, query_embeds)
# hs shape: (num_layers, batch_size, num_queries, c)
# init_reference shape: (num_queries, 2)
# inter_references shape: (num_layers, bs, num_queries, num_levels, 2)
(
hs,
init_reference,
inter_references,
enc_outputs_class,
enc_outputs_coord_unact,
) = self.transformer(srcs, masks, pos, query_embeds)
outputs_classes = [] outputs_classes = []
outputs_coords = [] outputs_coords = []
...@@ -173,27 +226,49 @@ class DeformableDETR(nn.Module): ...@@ -173,27 +226,49 @@ class DeformableDETR(nn.Module):
reference = init_reference reference = init_reference
else: else:
reference = inter_references[lvl - 1] reference = inter_references[lvl - 1]
# reference shape: (num_queries, 2)
reference = inverse_sigmoid(reference) reference = inverse_sigmoid(reference)
# shape (batch_size, num_queries, num_classes)
outputs_class = self.class_embed[lvl](hs[lvl]) outputs_class = self.class_embed[lvl](hs[lvl])
# shape (batch_size, num_queries, 4). 4-tuple (cx, cy, w, h)
assert not torch.any(
torch.isnan(hs[lvl])
), f"lvl {lvl}, NaN hs[lvl] {hs[lvl]}"
tmp = self.bbox_embed[lvl](hs[lvl]) tmp = self.bbox_embed[lvl](hs[lvl])
if reference.shape[-1] == 4: if reference.shape[-1] == 4:
tmp += reference tmp += reference
else: else:
assert reference.shape[-1] == 2 assert reference.shape[-1] == 2
tmp[..., :2] += reference tmp[..., :2] += reference
# shape (batch_size, num_queries, 4). 4-tuple (cx, cy, w, h)
assert not torch.any(torch.isnan(tmp)), f"NaN tmp {tmp}"
outputs_coord = tmp.sigmoid() outputs_coord = tmp.sigmoid()
assert not torch.any(
torch.isnan(outputs_coord)
), f"NaN outputs_coord {outputs_coord}"
outputs_classes.append(outputs_class) outputs_classes.append(outputs_class)
outputs_coords.append(outputs_coord) outputs_coords.append(outputs_coord)
# shape (num_levels, batch_size, num_queries, num_classes)
outputs_class = torch.stack(outputs_classes) outputs_class = torch.stack(outputs_classes)
# shape (num_levels, batch_size, num_queries, 4)
outputs_coord = torch.stack(outputs_coords) outputs_coord = torch.stack(outputs_coords)
out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]} out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
if self.aux_loss: if self.aux_loss:
out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord) out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord)
if self.two_stage: if self.two_stage:
enc_outputs_coord = enc_outputs_coord_unact.sigmoid() enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
out['enc_outputs'] = {'pred_logits': enc_outputs_class, 'pred_boxes': enc_outputs_coord} out["enc_outputs"] = {
"pred_logits": enc_outputs_class,
"pred_boxes": enc_outputs_coord,
}
return out return out
@torch.jit.unused @torch.jit.unused
...@@ -201,53 +276,62 @@ class DeformableDETR(nn.Module): ...@@ -201,53 +276,62 @@ class DeformableDETR(nn.Module):
# this is a workaround to make torchscript happy, as torchscript # this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such # doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list. # as a dict having both a Tensor and a list.
return [{'pred_logits': a, 'pred_boxes': b} return [
for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] {"pred_logits": a, "pred_boxes": b}
for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
]
class PostProcess(nn.Module): class PostProcess(nn.Module):
""" This module converts the model's output into the format expected by the coco api""" """This module converts the model's output into the format expected by the coco api"""
@torch.no_grad() @torch.no_grad()
def forward(self, outputs, target_sizes): def forward(self, outputs, target_sizes):
""" Perform the computation """Perform the computation
Parameters: Parameters:
outputs: raw outputs of the model outputs: raw outputs of the model
target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
For evaluation, this must be the original image size (before any data augmentation) For evaluation, this must be the original image size (before any data augmentation)
For visualization, this should be the image size after data augment, but before padding For visualization, this should be the image size after data augment, but before padding
""" """
out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]
assert len(out_logits) == len(target_sizes) assert len(out_logits) == len(target_sizes)
assert target_sizes.shape[1] == 2 assert target_sizes.shape[1] == 2
prob = out_logits.sigmoid() prob = out_logits.sigmoid()
topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1) topk_values, topk_indexes = torch.topk(
prob.view(out_logits.shape[0], -1), 100, dim=1
)
scores = topk_values scores = topk_values
topk_boxes = topk_indexes // out_logits.shape[2] topk_boxes = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2] labels = topk_indexes % out_logits.shape[2]
boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4)) boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
# and from relative [0, 1] to absolute [0, height] coordinates # and from relative [0, 1] to absolute [0, height] coordinates
img_h, img_w = target_sizes.unbind(1) img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
boxes = boxes * scale_fct[:, None, :] boxes = boxes * scale_fct[:, None, :]
results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)] results = [
{"scores": s, "labels": l, "boxes": b}
for s, l, b in zip(scores, labels, boxes)
]
return results return results
class MLP(nn.Module): class MLP(nn.Module):
""" Very simple multi-layer perceptron (also called FFN)""" """Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers): def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__() super().__init__()
self.num_layers = num_layers self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1) h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
def forward(self, x): def forward(self, x):
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
...@@ -256,7 +340,7 @@ class MLP(nn.Module): ...@@ -256,7 +340,7 @@ class MLP(nn.Module):
def build(args): def build(args):
num_classes = 20 if args.dataset_file != 'coco' else 91 num_classes = 20 if args.dataset_file != "coco" else 91
if args.dataset_file == "coco_panoptic": if args.dataset_file == "coco_panoptic":
num_classes = 250 num_classes = 250
device = torch.device(args.device) device = torch.device(args.device)
...@@ -277,8 +361,8 @@ def build(args): ...@@ -277,8 +361,8 @@ def build(args):
if args.masks: if args.masks:
model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None)) model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))
matcher = build_matcher(args) matcher = build_matcher(args)
weight_dict = {'loss_ce': args.cls_loss_coef, 'loss_bbox': args.bbox_loss_coef} weight_dict = {"loss_ce": args.cls_loss_coef, "loss_bbox": args.bbox_loss_coef}
weight_dict['loss_giou'] = args.giou_loss_coef weight_dict["loss_giou"] = args.giou_loss_coef
if args.masks: if args.masks:
weight_dict["loss_mask"] = args.mask_loss_coef weight_dict["loss_mask"] = args.mask_loss_coef
weight_dict["loss_dice"] = args.dice_loss_coef weight_dict["loss_dice"] = args.dice_loss_coef
...@@ -286,21 +370,25 @@ def build(args): ...@@ -286,21 +370,25 @@ def build(args):
if args.aux_loss: if args.aux_loss:
aux_weight_dict = {} aux_weight_dict = {}
for i in range(args.dec_layers - 1): for i in range(args.dec_layers - 1):
aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
aux_weight_dict.update({k + f'_enc': v for k, v in weight_dict.items()}) aux_weight_dict.update({k + f"_enc": v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict) weight_dict.update(aux_weight_dict)
losses = ['labels', 'boxes', 'cardinality'] losses = ["labels", "boxes", "cardinality"]
if args.masks: if args.masks:
losses += ["masks"] losses += ["masks"]
# num_classes, matcher, weight_dict, losses, focal_alpha=0.25 # num_classes, matcher, weight_dict, losses, focal_alpha=0.25
criterion = FocalLossSetCriterion(num_classes, matcher, weight_dict, losses, focal_alpha=args.focal_alpha) criterion = FocalLossSetCriterion(
num_classes, matcher, weight_dict, losses, focal_alpha=args.focal_alpha
)
criterion.to(device) criterion.to(device)
postprocessors = {'bbox': PostProcess()} postprocessors = {"bbox": PostProcess()}
if args.masks: if args.masks:
postprocessors['segm'] = PostProcessSegm() postprocessors["segm"] = PostProcessSegm()
if args.dataset_file == "coco_panoptic": if args.dataset_file == "coco_panoptic":
is_thing_map = {i: i <= 90 for i in range(201)} is_thing_map = {i: i <= 90 for i in range(201)}
postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85) postprocessors["panoptic"] = PostProcessPanoptic(
is_thing_map, threshold=0.85
)
return model, criterion, postprocessors return model, criterion, postprocessors
...@@ -63,12 +63,18 @@ class DETR(nn.Module): ...@@ -63,12 +63,18 @@ class DETR(nn.Module):
samples = nested_tensor_from_tensor_list(samples) samples = nested_tensor_from_tensor_list(samples)
features, pos = self.backbone(samples) features, pos = self.backbone(samples)
# src shape (B, C, H, W)
# mask shape (B, H, W)
src, mask = features[-1].decompose() src, mask = features[-1].decompose()
assert mask is not None assert mask is not None
# hs shape (NUM_LAYER, B, S, hidden_dim)
hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
# shape (NUM_LAYER, B, S, NUM_CLASS + 1)
outputs_class = self.class_embed(hs) outputs_class = self.class_embed(hs)
# shape (NUM_LAYER, B, S, 4)
outputs_coord = self.bbox_embed(hs).sigmoid() outputs_coord = self.bbox_embed(hs).sigmoid()
# pred_logits shape (B, S, NUM_CLASS + 1)
# pred_boxes shape (B, S, 4)
out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]} out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
if self.aux_loss: if self.aux_loss:
out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord) out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
......
...@@ -65,8 +65,8 @@ class HungarianMatcher(nn.Module): ...@@ -65,8 +65,8 @@ class HungarianMatcher(nn.Module):
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
# Also concat the target labels and boxes # Also concat the target labels and boxes
tgt_ids = torch.cat([v["labels"] for v in targets]) tgt_ids = torch.cat([v["labels"] for v in targets]) # [\sum_b NUM-BOX_b,]
tgt_bbox = torch.cat([v["boxes"] for v in targets]) tgt_bbox = torch.cat([v["boxes"] for v in targets]) # [\sum_b NUM-BOX_b, 4]
# Compute the classification cost. Contrary to the loss, we don't use the NLL, # Compute the classification cost. Contrary to the loss, we don't use the NLL,
# but approximate it in 1 - proba[target class]. # but approximate it in 1 - proba[target class].
...@@ -78,20 +78,23 @@ class HungarianMatcher(nn.Module): ...@@ -78,20 +78,23 @@ class HungarianMatcher(nn.Module):
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
else: else:
cost_class = -out_prob[:, tgt_ids] cost_class = -out_prob[:, tgt_ids] # shape [batch_size * num_queries, \sum_b NUM-BOX_b]
# Compute the L1 cost between boxes # Compute the L1 cost between boxes
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) # shape [batch_size * num_queries,\sum_b NUM-BOX_b]
# Compute the giou cost betwen boxes # Compute the giou cost betwen boxes
# shape [batch_size * num_queries, \sum_b NUM-BOX_b]
cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
# Final cost matrix # Final cost matrix
C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
C = C.view(bs, num_queries, -1).cpu() C = C.view(bs, num_queries, -1).cpu() # shape [batch_size, num_queries, \sum_b NUM-BOX_b]
sizes = [len(v["boxes"]) for v in targets] sizes = [len(v["boxes"]) for v in targets] # shape [batch_size,]
# each split c shape [batch_size, num_queries, NUM-BOX_b]
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
# A list where each item is [row_indices, col_indices]
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
......
...@@ -33,7 +33,7 @@ class PositionEmbeddingSine(nn.Module): ...@@ -33,7 +33,7 @@ class PositionEmbeddingSine(nn.Module):
mask = tensor_list.mask mask = tensor_list.mask
assert mask is not None assert mask is not None
not_mask = ~mask not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32) y_embed = not_mask.cumsum(1, dtype=torch.float32) # shape (B, H, W)
x_embed = not_mask.cumsum(2, dtype=torch.float32) x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize: if self.normalize:
eps = 1e-6 eps = 1e-6
...@@ -45,13 +45,13 @@ class PositionEmbeddingSine(nn.Module): ...@@ -45,13 +45,13 @@ class PositionEmbeddingSine(nn.Module):
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) # shape (N, )
pos_x = x_embed[:, :, :, None] / dim_t pos_x = x_embed[:, :, :, None] / dim_t # shape (B, H, W, N)
pos_y = y_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) # shape (B, H, W, N)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) # shape (B, H, W, N)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) # shape (B, 2*N, H, W)
return pos return pos
......
...@@ -39,10 +39,18 @@ class SetCriterion(nn.Module): ...@@ -39,10 +39,18 @@ class SetCriterion(nn.Module):
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
""" """
assert 'pred_logits' in outputs assert 'pred_logits' in outputs
# shape (batch_size, num_queries, NUM_CLASS + 1)
src_logits = outputs['pred_logits'] src_logits = outputs['pred_logits']
# idx = (batch_idx, src_idx)
# batch_idx shape [\sum_b num_match_b]
# src_idx shape [\sum_b num_match_b]
idx = self._get_src_permutation_idx(indices) idx = self._get_src_permutation_idx(indices)
# targets: List[Dict[str, torch.Tensor]]. Keys
# "labels": [NUM_BOX,]
# "boxes": [NUM_BOX, 4]
# target_classes_o shape [batch_size * num_match]
target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
# shape (batch_size, num_queries)
target_classes = torch.full(src_logits.shape[:2], self.num_classes, target_classes = torch.full(src_logits.shape[:2], self.num_classes,
dtype=torch.int64, device=src_logits.device) dtype=torch.int64, device=src_logits.device)
target_classes[idx] = target_classes_o target_classes[idx] = target_classes_o
...@@ -76,7 +84,9 @@ class SetCriterion(nn.Module): ...@@ -76,7 +84,9 @@ class SetCriterion(nn.Module):
""" """
assert 'pred_boxes' in outputs assert 'pred_boxes' in outputs
idx = self._get_src_permutation_idx(indices) idx = self._get_src_permutation_idx(indices)
# shape [\sum_b num_matches_b, 4]
src_boxes = outputs['pred_boxes'][idx] src_boxes = outputs['pred_boxes'][idx]
# shape [\sum_b num_matches_b, 4]
target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
...@@ -121,14 +131,14 @@ class SetCriterion(nn.Module): ...@@ -121,14 +131,14 @@ class SetCriterion(nn.Module):
def _get_src_permutation_idx(self, indices): def _get_src_permutation_idx(self, indices):
# permute predictions following indices # permute predictions following indices
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) # shape [\sum_b num_match_b]
src_idx = torch.cat([src for (src, _) in indices]) src_idx = torch.cat([src for (src, _) in indices]) # shape [\sum_b num_match_b]
return batch_idx, src_idx return batch_idx, src_idx
def _get_tgt_permutation_idx(self, indices): def _get_tgt_permutation_idx(self, indices):
# permute targets following indices # permute targets following indices
batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) # shape [\sum_b num_match_b]
tgt_idx = torch.cat([tgt for (_, tgt) in indices]) tgt_idx = torch.cat([tgt for (_, tgt) in indices]) # shape [\sum_b num_match_b]
return batch_idx, tgt_idx return batch_idx, tgt_idx
def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
...@@ -148,9 +158,12 @@ class SetCriterion(nn.Module): ...@@ -148,9 +158,12 @@ class SetCriterion(nn.Module):
targets: list of dicts, such that len(targets) == batch_size. targets: list of dicts, such that len(targets) == batch_size.
The expected keys in each dict depends on the losses applied, see each loss' doc The expected keys in each dict depends on the losses applied, see each loss' doc
""" """
# "pred_logits" shape (B, S, NUM_CLASS + 1)
# "pred_boxes" shape (B, S, 4)
outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'} outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
# Retrieve the matching between the outputs of the last layer and the targets # Retrieve the matching between the outputs of the last layer and the targets
# A list where each item is [row_indices, col_indices]
indices = self.matcher(outputs_without_aux, targets) indices = self.matcher(outputs_without_aux, targets)
# Compute the average number of target boxes accross all nodes, for normalization purposes # Compute the average number of target boxes accross all nodes, for normalization purposes
...@@ -378,5 +391,3 @@ class FocalLossSetCriterion(nn.Module): ...@@ -378,5 +391,3 @@ class FocalLossSetCriterion(nn.Module):
losses.update(l_dict) losses.update(l_dict)
return losses return losses
...@@ -47,17 +47,25 @@ class Transformer(nn.Module): ...@@ -47,17 +47,25 @@ class Transformer(nn.Module):
nn.init.xavier_uniform_(p) nn.init.xavier_uniform_(p)
def forward(self, src, mask, query_embed, pos_embed): def forward(self, src, mask, query_embed, pos_embed):
# src shape (B, C, H, W)
# mask shape (B, H, W)
# query_embed shape (M, C)
# pos_embed shape (B, C, H, W)
# flatten NxCxHxW to HWxNxC # flatten NxCxHxW to HWxNxC
bs, c, h, w = src.shape bs, c, h, w = src.shape
src = src.flatten(2).permute(2, 0, 1) src = src.flatten(2).permute(2, 0, 1) # shape (L, B, C)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1) pos_embed = pos_embed.flatten(2).permute(2, 0, 1) # shape (L, B, C)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # shape (M, B, C)
mask = mask.flatten(1) mask = mask.flatten(1) # shape (B, HxW)
tgt = torch.zeros_like(query_embed) tgt = torch.zeros_like(query_embed)
# memory shape (L, B, C)
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
# hs shape (NUM_LEVEL, S, B, C)
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
pos=pos_embed, query_pos=query_embed) pos=pos_embed, query_pos=query_embed)
# return shape (NUM_LEVEL, B, S, C) and (B, C, H, W)
return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
...@@ -74,7 +82,8 @@ class TransformerEncoder(nn.Module): ...@@ -74,7 +82,8 @@ class TransformerEncoder(nn.Module):
src_key_padding_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None): pos: Optional[Tensor] = None):
output = src output = src
# mask, shape (L, L)
# src_key_padding_mask, shape (B, L)
for layer in self.layers: for layer in self.layers:
output = layer(output, src_mask=mask, output = layer(output, src_mask=mask,
src_key_padding_mask=src_key_padding_mask, pos=pos) src_key_padding_mask=src_key_padding_mask, pos=pos)
...@@ -104,7 +113,11 @@ class TransformerDecoder(nn.Module): ...@@ -104,7 +113,11 @@ class TransformerDecoder(nn.Module):
output = tgt output = tgt
intermediate = [] intermediate = []
# tgt shape (L, B, C)
# tgt_mask shape (L, L)
# tgt_key_padding_mask shape (B, L)
# memory_mask shape (L, S)
# memory_key_padding_mask shape (B, S)
for layer in self.layers: for layer in self.layers:
output = layer(output, memory, tgt_mask=tgt_mask, output = layer(output, memory, tgt_mask=tgt_mask,
memory_mask=memory_mask, memory_mask=memory_mask,
...@@ -122,7 +135,7 @@ class TransformerDecoder(nn.Module): ...@@ -122,7 +135,7 @@ class TransformerDecoder(nn.Module):
if self.return_intermediate: if self.return_intermediate:
return torch.stack(intermediate) return torch.stack(intermediate)
# return shape (NUM_LAYER, L, B, C)
return output.unsqueeze(0) return output.unsqueeze(0)
...@@ -153,7 +166,9 @@ class TransformerEncoderLayer(nn.Module): ...@@ -153,7 +166,9 @@ class TransformerEncoderLayer(nn.Module):
src_mask: Optional[Tensor] = None, src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None): pos: Optional[Tensor] = None):
q = k = self.with_pos_embed(src, pos) q = k = self.with_pos_embed(src, pos) # shape (L, B, D)
# src mask, shape (L, L)
# src_key_padding_mask: shape (B, L)
src2 = self.self_attn(q, k, src, attn_mask=src_mask, src2 = self.self_attn(q, k, src, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0] key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2) src = src + self.dropout1(src2)
...@@ -218,11 +233,17 @@ class TransformerDecoderLayer(nn.Module): ...@@ -218,11 +233,17 @@ class TransformerDecoderLayer(nn.Module):
memory_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None, pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None): query_pos: Optional[Tensor] = None):
# tgt shape (L, B, C)
# tgt_mask shape (L, L)
# tgt_key_padding_mask shape (B, L)
q = k = self.with_pos_embed(tgt, query_pos) q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(q, k, tgt, attn_mask=tgt_mask, tgt2 = self.self_attn(q, k, tgt, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0] key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2) tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt) tgt = self.norm1(tgt)
# memory_mask shape (L, S)
# memory_key_padding_mask shape (B, S)
# query_pos shape (L, B, C)
tgt2 = self.multihead_attn(self.with_pos_embed(tgt, query_pos), tgt2 = self.multihead_attn(self.with_pos_embed(tgt, query_pos),
self.with_pos_embed(memory, pos), self.with_pos_embed(memory, pos),
memory, attn_mask=memory_mask, memory, attn_mask=memory_mask,
...@@ -232,6 +253,7 @@ class TransformerDecoderLayer(nn.Module): ...@@ -232,6 +253,7 @@ class TransformerDecoderLayer(nn.Module):
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2) tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt) tgt = self.norm3(tgt)
# return tgt shape (L, B, C)
return tgt return tgt
def forward_pre(self, tgt, memory, def forward_pre(self, tgt, memory,
......
...@@ -63,8 +63,11 @@ class MSDeformAttn(nn.Module): ...@@ -63,8 +63,11 @@ class MSDeformAttn(nn.Module):
def _reset_parameters(self): def _reset_parameters(self):
constant_(self.sampling_offsets.weight.data, 0.) constant_(self.sampling_offsets.weight.data, 0.)
# shape (num_heads,)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
# shape (2 * num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
# shape (num_heads, num_levels, num_points, 2)
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
for i in range(self.n_points): for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1 grid_init[:, :, i, :] *= i + 1
......
...@@ -50,8 +50,8 @@ def generalized_box_iou(boxes1, boxes2): ...@@ -50,8 +50,8 @@ def generalized_box_iou(boxes1, boxes2):
""" """
# degenerate boxes gives inf / nan results # degenerate boxes gives inf / nan results
# so do an early check # so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all() assert (boxes1[:, 2:] >= boxes1[:, :2]).all(), f"incorrect boxes, boxes1 {boxes1}"
assert (boxes2[:, 2:] >= boxes2[:, :2]).all() assert (boxes2[:, 2:] >= boxes2[:, :2]).all(), f"incorrect boxes, boxes2 {boxes2}"
iou, union = box_iou(boxes1, boxes2) iou, union = box_iou(boxes1, boxes2)
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
......