Commit a4f06b88 authored by Zhicheng Yan, committed by Facebook GitHub Bot

stabilize deformable DETR training

Summary:
Deformable DETR training can be unstable due to iterative box refinement in the transformer decoder. To stabilize training, this diff introduces two changes:
- Remove the unnecessary use of inverse sigmoid. When box refinement is turned on, the inverse sigmoid can be avoided entirely.
- In the `DeformableTransformer` class, detach `init_reference_out` before passing it into the decoder to update the memory and compute per-decoder-layer reference points (see the sketch below).
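
For reference, here is a minimal, self-contained sketch of the stabilized refinement loop. It is not the code in this diff; `DummyDecoderLayer`, `refine`, and the tensor shapes are hypothetical stand-ins. It illustrates the two points above: (1) reference points stay in unactivated (logit) space, so inverse sigmoid is never needed, and (2) refined points are detached before the next layer to block gradients through the refinement chain.

import torch
import torch.nn as nn


class DummyDecoderLayer(nn.Module):
    """Stand-in for a deformable decoder layer (hypothetical)."""

    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, tgt, reference_points, memory):
        # A real layer would use reference_points (values in (0, 1)) for sampling.
        return torch.relu(self.fc(tgt))


def refine(layers, bbox_embed, tgt, reference_points_unact, memory):
    output = tgt
    inter_refs_unact = []
    for lid, layer in enumerate(layers):
        # apply sigmoid only where normalized coordinates are needed
        reference_points = reference_points_unact.sigmoid()
        output = layer(output, reference_points, memory)
        # per-layer box delta is added directly in logit space (no inverse sigmoid)
        new_refs_unact = bbox_embed[lid](output) + reference_points_unact
        # block gradient backpropagation here to stabilize optimization
        reference_points_unact = new_refs_unact.detach()
        inter_refs_unact.append(new_refs_unact)
    # shapes: (bs, num_queries, dim), (num_layers, bs, num_queries, 4)
    return output, torch.stack(inter_refs_unact)


if __name__ == "__main__":
    dim, num_layers, bs, nq = 32, 3, 2, 10
    layers = nn.ModuleList(DummyDecoderLayer(dim) for _ in range(num_layers))
    bbox_embed = nn.ModuleList(nn.Linear(dim, 4) for _ in range(num_layers))
    tgt = torch.randn(bs, nq, dim)
    init_reference_unact = torch.randn(bs, nq, 4)  # analogous to the detached init_reference_out
    out, refs = refine(layers, bbox_embed, tgt, init_reference_unact, memory=None)
    print(out.shape, refs.shape)  # torch.Size([2, 10, 32]) torch.Size([3, 2, 10, 4])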

Reviewed By: zhanghang1989

Differential Revision: D29903599

fbshipit-source-id: a374ba161be0d7bcdfb42553044c4c6700e92623
parent 0a458091
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import numpy as np
import torch
import torch.nn.functional as F
from detectron2.layers import ShapeSpec
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, detector_postprocess
from detectron2.structures import Boxes, ImageList, Instances, BitMasks
from detr.datasets.coco import convert_coco_poly_to_mask
@@ -25,7 +25,6 @@ from ..util.misc import (
get_world_size,
interpolate,
is_dist_avail_and_initialized,
inverse_sigmoid,
)
from .backbone import build_backbone
from .deformable_transformer import build_deforamble_transformer
@@ -119,6 +118,7 @@ class DeformableDETR(nn.Module):
self.class_embed.bias.data = torch.ones(num_classes) * bias_value
nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
for proj in self.input_proj:
nn.init.xavier_uniform_(proj[0].weight, gain=1)
nn.init.constant_(proj[0].bias, 0)
@@ -132,6 +132,7 @@ class DeformableDETR(nn.Module):
if with_box_refine:
self.class_embed = _get_clones(self.class_embed, num_pred)
self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
# initialize the box scale height/width at the 1st scale to be 0.1
nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
# hack implementation for iterative bounding box refinement
self.transformer.decoder.bbox_embed = self.bbox_embed
@@ -197,6 +198,7 @@ class DeformableDETR(nn.Module):
else:
src = self.input_proj[l](srcs[-1])
b, _, h, w = src.size()
# mask shape (batch_size, h_l, w_l)
mask = F.interpolate(sample_mask, size=src.shape[-2:]).to(torch.bool)[0]
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
srcs.append(src)
@@ -209,7 +211,7 @@ class DeformableDETR(nn.Module):
query_embeds = self.query_embed.weight
# hs shape: (num_layers, batch_size, num_queries, c)
# init_reference shape: (num_queries, 2)
# init_reference shape: (batch_size, num_queries, 2)
# inter_references shape: (num_layers, bs, num_queries, num_levels, 2)
(
hs,
@@ -222,20 +224,14 @@ class DeformableDETR(nn.Module):
outputs_classes = []
outputs_coords = []
for lvl in range(hs.shape[0]):
# reference shape: (num_queries, 2)
if lvl == 0:
reference = init_reference
else:
reference = inter_references[lvl - 1]
# reference shape: (num_queries, 2)
reference = inverse_sigmoid(reference)
# shape (batch_size, num_queries, num_classes)
outputs_class = self.class_embed[lvl](hs[lvl])
# shape (batch_size, num_queries, 4). 4-tuple (cx, cy, w, h)
assert not torch.any(
torch.isnan(hs[lvl])
), f"lvl {lvl}, NaN hs[lvl] {hs[lvl]}"
tmp = self.bbox_embed[lvl](hs[lvl])
if reference.shape[-1] == 4:
tmp += reference
@@ -243,15 +239,8 @@ class DeformableDETR(nn.Module):
assert reference.shape[-1] == 2
tmp[..., :2] += reference
# shape (batch_size, num_queries, 4). 4-tuple (cx, cy, w, h)
assert not torch.any(torch.isnan(tmp)), f"NaN tmp {tmp}"
outputs_coord = tmp.sigmoid()
assert not torch.any(
torch.isnan(outputs_coord)
), f"NaN outputs_coord {outputs_coord}"
outputs_classes.append(outputs_class)
outputs_coords.append(outputs_coord)
# shape (num_levels, batch_size, num_queries, num_classes)
@@ -332,6 +321,13 @@ class MLP(nn.Module):
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
# initialize FC weights and bias
for i, layer in enumerate(self.layers):
if i < num_layers - 1:
nn.init.kaiming_uniform_(layer.weight, a=1)
else:
nn.init.constant_(layer.weight, 0)
nn.init.constant_(layer.bias, 0)
def forward(self, x):
for i, layer in enumerate(self.layers):
@@ -16,7 +16,6 @@ from torch import nn
from torch.nn.init import xavier_uniform_, constant_, normal_
from ..modules import MSDeformAttn
from ..util.misc import inverse_sigmoid
class DeformableTransformer(nn.Module):
@@ -259,7 +258,6 @@ class DeformableTransformer(nn.Module):
output_memory, output_proposals = self.gen_encoder_output_proposals(
memory, mask_flatten, spatial_shapes
)
# hack implementation for two-stage Deformable DETR
# shape (bs, K, num_classes)
enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](
@@ -270,7 +268,6 @@ class DeformableTransformer(nn.Module):
self.decoder.bbox_embed[self.decoder.num_layers](output_memory)
+ output_proposals
)
topk = self.two_stage_num_proposals
# topk_proposals: indices of top items. Shape (bs, top_k)
# TODO (zyan3): use a standalone class_embed layer with 2 output channels?
@@ -280,9 +277,7 @@ class DeformableTransformer(nn.Module):
enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
)
topk_coords_unact = topk_coords_unact.detach()
# reference_points shape (bs, top_k, 4). value \in (0, 1)
reference_points = topk_coords_unact.sigmoid()
init_reference_out = reference_points
init_reference_out = reference_points_unact = topk_coords_unact
# shape (bs, top_k, C=512)
pos_trans_out = self.pos_trans_norm(
self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))
@@ -297,16 +292,17 @@ class DeformableTransformer(nn.Module):
query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
# tgt shape: (batch_size, num_queries, c)
tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
# reference_points shape: (batch_size, num_queries, 2), value \in (0, 1)
reference_points = self.reference_points(query_embed).sigmoid()
init_reference_out = reference_points
# init_reference_out shape: (batch_size, num_queries, 2), value \in (0, 1)
init_reference_out = self.reference_points(query_embed)
# block gradient backpropagation here to stabilize optimization
reference_points_unact = init_reference_out.detach()
# decoder
# hs shape: (num_layers, batch_size, num_queries, c)
# inter_references shape: (num_layers, batch_size, num_queries, num_levels, 2)
hs, inter_references = self.decoder(
tgt,
reference_points,
reference_points_unact,
memory,
spatial_shapes,
level_start_index,
@@ -562,7 +558,7 @@ class DeformableTransformerDecoder(nn.Module):
def forward(
self,
tgt,
reference_points,
reference_points_unact,
src,
src_spatial_shapes,
src_level_start_index,
@@ -573,7 +569,7 @@ class DeformableTransformerDecoder(nn.Module):
"""
Args:
tgt: tensor, shape (batch_size, num_queries, c)
reference_points: tensor, shape (batch_size, num_queries, 2 or 4).
reference_points_unact: tensor, shape (batch_size, num_queries, 2 or 4).
values \in (0, 1)
src: tensor, shape (batch_size, K, c) where K = \sum_l H_l * w_l
src_spatial_shapes: tensor, shape (num_levels, 2)
@@ -587,6 +583,8 @@ class DeformableTransformerDecoder(nn.Module):
intermediate = []
intermediate_reference_points = []
for lid, layer in enumerate(self.layers):
reference_points = reference_points_unact.sigmoid()
if reference_points.shape[-1] == 4:
# shape: (bs, num_queries, 1, 4) * (bs, 1, num_levels, 4) = (bs, num_queries, num_levels, 4)
reference_points_input = (
@@ -609,35 +607,34 @@ class DeformableTransformerDecoder(nn.Module):
src_level_start_index,
src_padding_mask,
)
assert not torch.any(torch.isnan(output)), f"NaN, lid {lid}, {output}"
# hack implementation for iterative bounding box refinement
if self.bbox_embed is not None:
tmp = self.bbox_embed[lid](output)
if reference_points.shape[-1] == 4:
new_reference_points = tmp + inverse_sigmoid(reference_points)
new_reference_points = new_reference_points.sigmoid()
new_reference_points_unact = tmp + reference_points_unact
else:
assert reference_points.shape[-1] == 2
new_reference_points = tmp
new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(
reference_points
new_reference_points_unact = tmp
new_reference_points_unact[..., :2] = (
tmp[..., :2] + reference_points_unact
)
new_reference_points = new_reference_points.sigmoid()
# block gradient backpropagation here to stabilize optimization
reference_points = new_reference_points.detach()
new_reference_points_unact = new_reference_points_unact.detach()
reference_points_unact = new_reference_points_unact
else:
new_reference_points_unact = reference_points_unact
if self.return_intermediate:
intermediate.append(output)
intermediate_reference_points.append(reference_points)
intermediate_reference_points.append(new_reference_points_unact)
if self.return_intermediate:
# shape 1: (num_layers, batch_size, num_queries, c)
# shape 2: (num_layers, bs, num_queries, num_levels, 2)
return torch.stack(intermediate), torch.stack(intermediate_reference_points)
# output shape: (batch_size, num_queries, c)
# reference_points shape: (bs, num_queries, num_levels, 2)
return output, reference_points
# new_reference_points_unact shape: (bs, num_queries, num_levels, 2)
return output, new_reference_points_unact
def _get_clones(module, N):
......