Commit aea87f6c authored by Zhicheng Yan, committed by Facebook GitHub Bot

fix two-stage DF-DETR

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/106

# 2-stage DF-DETR

DF-DETR supports 2-stage detection. In the 1st stage, we detect class-agnostic boxes using the feature pyramid (a.k.a. `memory` in the code) computed by the encoder.
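Concretely, the 1st stage scores every encoder position as foreground vs. background and keeps the top-k positions as class-agnostic proposals, which then seed the decoder queries. A rough sketch of that selection, with `class_embed` and `bbox_embed` as hypothetical stand-ins for the encoder prediction heads:

```python
import torch
import torch.nn as nn

bs, K, C, topk = 2, 1024, 256, 300  # batch, flattened positions, channels, proposals

# hypothetical stand-ins for the encoder prediction heads
class_embed = nn.Linear(C, 1)  # a single foreground logit per position
bbox_embed = nn.Sequential(nn.Linear(C, C), nn.ReLU(), nn.Linear(C, 4))

memory = torch.randn(bs, K, C)     # encoder feature pyramid, flattened over levels
proposals = torch.randn(bs, K, 4)  # unactivated (pre-sigmoid) anchor boxes

scores = class_embed(memory)                   # (bs, K, 1)
coords_unact = bbox_embed(memory) + proposals  # (bs, K, 4)

# keep the k highest-scoring positions as class-agnostic proposals
topk_idx = torch.topk(scores[..., 0], topk, dim=1)[1]  # (bs, topk)
topk_coords_unact = torch.gather(
    coords_unact, 1, topk_idx.unsqueeze(-1).repeat(1, 1, 4)
).detach()  # detached; these become the decoder's initial reference points
```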

The current implementation has a few flaws:
- In `setcriterion.py`, when computing the loss for the encoder's 1st-stage predictions, `num_boxes` should be reduced across GPUs and clamped to be at least one to avoid a divide-by-zero bug. The current implementation produces NaN from a divide-by-zero when `num_boxes` is zero (e.g., no box annotations in the cropped input image).
- `gen_encoder_output_proposals()` manually fills in `float("inf")` at invalid spatial positions outside the actual image size. However, those positions are not guaranteed to be excluded from the top-scored positions, and `float("inf")` can easily cause the affected parameters to be updated to NaN.
- The encoder's `class_embed` should have 1 output channel rather than `num_classes` channels, because we only need to predict the probability of a box being foreground.

This diff fixes the issues above; a rough sketch of the first two fixes is shown below.
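A minimal, hypothetical sketch of those two fixes (standalone code mirroring the diff below, not the exact d2go implementation): the box count is all-reduced across workers and clamped before being used as the loss normalizer, and invalid positions receive a large-but-finite negative logit instead of `float("inf")`:

```python
import torch
import torch.distributed as dist

# finite sentinel; float("-inf") logits can propagate NaN through backward
NEG_INF = -10000.0

def reduce_num_boxes(targets, device):
    # average the number of target boxes across GPUs, clamped so that the
    # loss normalizer can never be zero
    num_boxes = sum(len(t["labels"]) for t in targets)
    num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(num_boxes)
        num_boxes = num_boxes / dist.get_world_size()
    return torch.clamp(num_boxes, min=1).item()

# mask invalid positions in place with a finite negative logit, so that topk
# never selects them and gradients stay finite
logits = torch.randn(2, 1024, 1)                       # (bs, K, 1) foreground logits
padding_mask = torch.zeros(2, 1024, dtype=torch.bool)  # (bs, K) padded positions
logits.masked_fill_(padding_mask.unsqueeze(-1), NEG_INF)
```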

# Gradient blocking in decoder

Currently, the gradient of the reference points is blocked at each decoding layer to improve numerical stability during training.
This diff adds an option `MODEL.DETR.DECODER_BLOCK_GRAD`; when it is False, we do NOT block the gradient. Empirically, we find this leads to better box AP. A sketch of the resulting decoder behavior follows.
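The option maps to a `block_grad` flag on the decoder, and each layer's box-refinement step conditionally detaches the refined reference points, roughly as in this sketch (`delta` stands in for the predicted box offset; names are illustrative):

```python
import torch

def refine_reference_points(
    ref_unact: torch.Tensor, delta: torch.Tensor, block_grad: bool
) -> torch.Tensor:
    # one decoder layer's refinement of the unactivated reference points
    new_ref_unact = delta + ref_unact
    if block_grad:
        # MODEL.DETR.DECODER_BLOCK_GRAD = True: stop gradients from flowing
        # through the reference points into earlier decoder layers
        new_ref_unact = new_ref_unact.detach()
    return new_ref_unact
```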

Reviewed By: zhanghang1989

Differential Revision: D30325396

fbshipit-source-id: 7d7add1e05888adda6e46cc6886117170daa22d4
parent 7677f3ec
@@ -33,7 +33,7 @@ def add_detr_config(cfg):
cfg.MODEL.DETR.NO_OBJECT_WEIGHT = 0.1
cfg.MODEL.DETR.WITH_BOX_REFINE = False
cfg.MODEL.DETR.TWO_STAGE = False
cfg.MODEL.DETR.DECODER_BLOCK_GRAD = True
# TRANSFORMER
cfg.MODEL.DETR.NHEADS = 8
cfg.MODEL.DETR.DROPOUT = 0.1
@@ -72,6 +72,7 @@ class DeformableDETR(nn.Module):
self.num_queries = num_queries
self.transformer = transformer
hidden_dim = transformer.d_model
# We will use sigmoid activation and focal loss
self.class_embed = nn.Linear(hidden_dim, num_classes)
self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
self.num_feature_levels = num_feature_levels
@@ -116,19 +117,12 @@ class DeformableDETR(nn.Module):
prior_prob = 0.01
bias_value = -math.log((1 - prior_prob) / prior_prob)
self.class_embed.bias.data = torch.ones(num_classes) * bias_value
nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
for proj in self.input_proj:
nn.init.xavier_uniform_(proj[0].weight, gain=1)
nn.init.constant_(proj[0].bias, 0)
# if two-stage, the last class_embed and bbox_embed are for region proposal generation
num_pred = (
(transformer.decoder.num_layers + 1)
if two_stage
else transformer.decoder.num_layers
)
num_pred = transformer.decoder.num_layers
if with_box_refine:
self.class_embed = _get_clones(self.class_embed, num_pred)
self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
@@ -145,10 +139,16 @@ class DeformableDETR(nn.Module):
self.transformer.decoder.bbox_embed = None
if two_stage:
# hack implementation for two-stage
self.transformer.decoder.class_embed = self.class_embed
# We only predict foreground/background at the output of the encoder
class_embed = nn.Linear(hidden_dim, 1)
class_embed.bias.data = torch.ones(1) * bias_value
self.transformer.encoder.class_embed = class_embed
for box_embed in self.bbox_embed:
nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)
self.transformer.encoder.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
def forward(self, samples: NestedTensor):
"""The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
@@ -16,6 +16,11 @@ from torch import nn
from torch.nn.init import xavier_uniform_, constant_, normal_
from ..modules import MSDeformAttn
from ..util.misc import inverse_sigmoid
# we do not use float("-inf") to avoid potential NaN during training
NEG_INF = -10000.0
class DeformableTransformer(nn.Module):
@@ -34,6 +39,7 @@ class DeformableTransformer(nn.Module):
enc_n_points=4,
two_stage=False,
two_stage_num_proposals=300,
decoder_block_grad=True,
):
super().__init__()
@@ -63,7 +69,10 @@ class DeformableTransformer(nn.Module):
dec_n_points,
)
self.decoder = DeformableTransformerDecoder(
decoder_layer, num_decoder_layers, return_intermediate_dec
decoder_layer,
num_decoder_layers,
return_intermediate_dec,
decoder_block_grad,
)
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
@@ -81,13 +90,13 @@ class DeformableTransformer(nn.Module):
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MSDeformAttn):
m._reset_parameters()
if not self.two_stage:
xavier_uniform_(self.reference_points.weight.data, gain=1.0)
constant_(self.reference_points.bias.data, 0.0)
xavier_uniform_(self.reference_points.weight, gain=1.0)
constant_(self.reference_points.bias, 0.0)
normal_(self.level_embed)
def get_proposal_pos_embed(self, proposals):
@@ -163,14 +172,8 @@ class DeformableTransformer(nn.Module):
output_proposals_valid = (
(output_proposals > 0.01) & (output_proposals < 0.99)
).all(-1, keepdim=True)
# inverse sigmoid
output_proposals = torch.log(output_proposals / (1 - output_proposals))
output_proposals = output_proposals.masked_fill(
memory_padding_mask.unsqueeze(-1), float("inf")
)
output_proposals = output_proposals.masked_fill(
~output_proposals_valid, float("inf")
)
output_proposals = inverse_sigmoid(output_proposals)
# memory: shape (bs, K, C)
output_memory = memory
# memory_padding_mask: shape (bs, K)
@@ -179,7 +182,7 @@ class DeformableTransformer(nn.Module):
)
output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
output_memory = self.enc_output_norm(self.enc_output(output_memory))
return output_memory, output_proposals
return output_memory, output_proposals, output_proposals_valid
def get_valid_ratio(self, mask):
_, H, W = mask.shape
@@ -226,7 +229,7 @@ class DeformableTransformer(nn.Module):
src_flatten = torch.cat(src_flatten, 1)
# mask_flatten shape: (bs, K)
mask_flatten = torch.cat(mask_flatten, 1)
# mask_flatten shape: (bs, K, c)
# lvl_pos_embed_flatten shape: (bs, K, c)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
# spatial_shapes shape: (num_levels, 2)
spatial_shapes = torch.as_tensor(
@@ -238,7 +241,6 @@ class DeformableTransformer(nn.Module):
)
# valid_ratios shape: (bs, num_levels, 2)
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
# encoder
# memory shape (bs, K, C) where K = \sum_l H_l * w_l
memory = self.encoder(
@@ -255,29 +257,34 @@ class DeformableTransformer(nn.Module):
if self.two_stage:
# output_memory shape (bs, K, C)
# output_proposals shape (bs, K, 4)
output_memory, output_proposals = self.gen_encoder_output_proposals(
memory, mask_flatten, spatial_shapes
)
# output_proposals_valid shape (bs, K, 1)
(
output_memory,
output_proposals,
output_proposals_valid,
) = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)
# hack implementation for two-stage Deformable DETR
# shape (bs, K, num_classes)
enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](
output_memory
)
# shape (bs, K, 1)
enc_outputs_class = self.encoder.class_embed(output_memory)
# fill in a large negative foreground logit (NEG_INF) at invalid positions so that
# we will never pick top-scored proposals at those positions
enc_outputs_class.masked_fill_(mask_flatten.unsqueeze(-1), NEG_INF)
enc_outputs_class.masked_fill_(~output_proposals_valid, NEG_INF)
# shape (bs, K, 4)
enc_outputs_coord_unact = (
self.decoder.bbox_embed[self.decoder.num_layers](output_memory)
+ output_proposals
self.encoder.bbox_embed(output_memory) + output_proposals
)
topk = self.two_stage_num_proposals
# topk_proposals: indices of top items. Shape (bs, top_k)
# TODO (zyan3): use a standalone class_embed layer with 2 output channels?
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
# topk_coords_unact shape (bs, top_k, 4)
topk_coords_unact = torch.gather(
enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
)
topk_coords_unact = topk_coords_unact.detach()
init_reference_out = reference_points_unact = topk_coords_unact
init_reference_out = topk_coords_unact
# shape (bs, top_k, C=512)
pos_trans_out = self.pos_trans_norm(
self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))
@@ -292,17 +299,15 @@ class DeformableTransformer(nn.Module):
query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
# tgt shape: (batch_size, num_queries, c)
tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
# init_reference_out shape: (batch_size, num_queries, 2), value \in (0, 1)
# init_reference_out shape: (batch_size, num_queries, 2)
init_reference_out = self.reference_points(query_embed)
# block gradient backpropagation here to stabilize optimization
reference_points_unact = init_reference_out.detach()
# decoder
# hs shape: (num_layers, batch_size, num_queries, c)
# inter_references shape: (num_layers, batch_size, num_queries, num_levels, 2)
hs, inter_references = self.decoder(
tgt,
reference_points_unact,
init_reference_out,
memory,
spatial_shapes,
level_start_index,
@@ -546,14 +551,16 @@ class DeformableTransformerDecoderLayer(nn.Module):
class DeformableTransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers, return_intermediate=False):
def __init__(
self, decoder_layer, num_layers, return_intermediate=False, block_grad=True
):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.return_intermediate = return_intermediate
# hack implementation for iterative bounding box refinement and two-stage Deformable DETR
self.bbox_embed = None
self.class_embed = None
self.block_grad = block_grad
def forward(
self,
@@ -619,7 +626,9 @@ class DeformableTransformerDecoder(nn.Module):
tmp[..., :2] + reference_points_unact
)
# block gradient backpropagation here to stabilize optimization
new_reference_points_unact = new_reference_points_unact.detach()
if self.block_grad:
new_reference_points_unact = new_reference_points_unact.detach()
reference_points_unact = new_reference_points_unact
else:
new_reference_points_unact = reference_points_unact
@@ -15,6 +16,16 @@ from ..util.misc import (
from .segmentation import dice_loss, sigmoid_focal_loss
def _reduce_num_boxes(targets, device):
# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
if is_dist_avail_and_initialized():
torch.distributed.all_reduce(num_boxes)
num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
return num_boxes
class SetCriterion(nn.Module):
"""This class computes the loss for DETR.
The process happens in two steps:
@@ -78,6 +88,50 @@ class SetCriterion(nn.Module):
losses["class_error"] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
return losses
def foreground_background_loss_labels(
self, outputs, targets, indices, num_boxes, log=True
):
assert "pred_logits" in outputs
# shape (batch_size, num_queries, 1)
src_logits = outputs["pred_logits"]
batch_size, num_queries = src_logits.shape[:2]
assert src_logits.shape[2] == 1, f"expected 1 class, got {src_logits.shape[2]}"
idx = self._get_src_permutation_idx(indices)
target_classes_o = torch.cat(
[t["labels"][J] for t, (_, J) in zip(targets, indices)]
)
target_classes = torch.full(
src_logits.shape[:2],
1,
dtype=torch.int64,
device=src_logits.device,
)
target_classes[idx] = target_classes_o
target_classes_onehot = torch.zeros(
[src_logits.shape[0], src_logits.shape[1], 2],
dtype=src_logits.dtype,
layout=src_logits.layout,
device=src_logits.device,
)
target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
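# drop the last (background) channel; channel 0 is the foreground indicator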
target_classes_onehot = target_classes_onehot[:, :, :-1]
loss_ce = (
sigmoid_focal_loss(
src_logits,
target_classes_onehot,
num_boxes,
alpha=self.focal_alpha,
gamma=2,
)
* src_logits.shape[1]
)
return {"loss_ce": loss_ce}
@torch.no_grad()
def loss_cardinality(self, outputs, targets, indices, num_boxes):
"""Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
@@ -181,21 +235,24 @@ class SetCriterion(nn.Module):
assert loss in loss_map, f"do you really want to compute {loss} loss?"
return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
def get_foreground_background_loss(
self, loss, outputs, targets, indices, num_boxes, **kwargs
):
loss_map = {
"labels": self.forground_background_loss_labels,
"cardinality": self.loss_cardinality,
"boxes": self.loss_boxes,
}
assert loss in loss_map, f"do you really want to compute {loss} loss?"
return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
def _forward(self, outputs, outputs_without_aux, targets):
# Retrieve the matching between the outputs of the last layer and the targets
# A list where each item is [row_indices, col_indices]
indices = self.matcher(outputs_without_aux, targets)
# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["labels"]) for t in targets)
num_boxes = torch.as_tensor(
[num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device
)
if is_dist_avail_and_initialized():
torch.distributed.all_reduce(num_boxes)
num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
num_boxes = _reduce_num_boxes(targets, next(iter(outputs.values())).device)
# Compute all the requested losses
losses = {}
for loss in self.losses:
@@ -257,6 +314,7 @@ class FocalLossSetCriterion(SetCriterion):
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
"""
assert "pred_logits" in outputs
# shape (batch_size, num_queries, num_classes)
src_logits = outputs["pred_logits"]
idx = self._get_src_permutation_idx(indices)
@@ -313,8 +371,7 @@ class FocalLossSetCriterion(SetCriterion):
losses = self._forward(outputs, outputs_without_aux, targets)
if "enc_outputs" in outputs:
num_boxes = sum(len(t["labels"]) for t in targets)
num_boxes = _reduce_num_boxes(targets, next(iter(outputs.values())).device)
enc_outputs = outputs["enc_outputs"]
bin_targets = copy.deepcopy(targets)
for bt in bin_targets:
@@ -328,7 +385,7 @@ class FocalLossSetCriterion(SetCriterion):
if loss == "labels":
# Logging is enabled only for the last layer
kwargs["log"] = False
l_dict = self.get_loss(
l_dict = self.get_foreground_background_loss(
loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs
)
l_dict = {k + "_enc": v for k, v in l_dict.items()}
@@ -50,8 +50,8 @@ def generalized_box_iou(boxes1, boxes2):
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all(), f"incorrect boxes, boxes1 {boxes1}"
assert (boxes2[:, 2:] >= boxes2[:, :2]).all(), f"incorrect boxes, boxes1 {boxes2}"
assert (boxes1[:, 2:] > boxes1[:, :2]).all(), f"incorrect boxes, boxes1 {boxes1}"
assert (boxes2[:, 2:] > boxes2[:, :2]).all(), f"incorrect boxes, boxes2 {boxes2}"
iou, union = box_iou(boxes1, boxes2)
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])