Commit a4f06b88 authored by Zhicheng Yan's avatar Zhicheng Yan Committed by Facebook GitHub Bot
Browse files

stabilize deformable DETR training

Summary:
Deformable DETR training can be unstable due to iterative box refinement in the transformer decoder. To stabilize the training, introduce two changes
- Remove the unnecessary use of inverse sigmoid.
It is possible to completely avoid using inverse sigmoid when box refinement is turned on.
- In `DeformableTransformer` class, detach `init_reference_out` before passing it into the decoder to update memory and compute per-decoder-layer reference points.

Reviewed By: zhanghang1989

Differential Revision: D29903599

fbshipit-source-id: a374ba161be0d7bcdfb42553044c4c6700e92623
parent 0a458091
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from detectron2.layers import ShapeSpec
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, detector_postprocess from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, detector_postprocess
from detectron2.structures import Boxes, ImageList, Instances, BitMasks from detectron2.structures import Boxes, ImageList, Instances, BitMasks
from detr.datasets.coco import convert_coco_poly_to_mask from detr.datasets.coco import convert_coco_poly_to_mask
......
...@@ -25,7 +25,6 @@ from ..util.misc import ( ...@@ -25,7 +25,6 @@ from ..util.misc import (
get_world_size, get_world_size,
interpolate, interpolate,
is_dist_avail_and_initialized, is_dist_avail_and_initialized,
inverse_sigmoid,
) )
from .backbone import build_backbone from .backbone import build_backbone
from .deformable_transformer import build_deforamble_transformer from .deformable_transformer import build_deforamble_transformer
...@@ -119,6 +118,7 @@ class DeformableDETR(nn.Module): ...@@ -119,6 +118,7 @@ class DeformableDETR(nn.Module):
self.class_embed.bias.data = torch.ones(num_classes) * bias_value self.class_embed.bias.data = torch.ones(num_classes) * bias_value
nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
for proj in self.input_proj: for proj in self.input_proj:
nn.init.xavier_uniform_(proj[0].weight, gain=1) nn.init.xavier_uniform_(proj[0].weight, gain=1)
nn.init.constant_(proj[0].bias, 0) nn.init.constant_(proj[0].bias, 0)
...@@ -132,6 +132,7 @@ class DeformableDETR(nn.Module): ...@@ -132,6 +132,7 @@ class DeformableDETR(nn.Module):
if with_box_refine: if with_box_refine:
self.class_embed = _get_clones(self.class_embed, num_pred) self.class_embed = _get_clones(self.class_embed, num_pred)
self.bbox_embed = _get_clones(self.bbox_embed, num_pred) self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
# initialize the box scale height/width at the 1st scale to be 0.1
nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0) nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
# hack implementation for iterative bounding box refinement # hack implementation for iterative bounding box refinement
self.transformer.decoder.bbox_embed = self.bbox_embed self.transformer.decoder.bbox_embed = self.bbox_embed
...@@ -197,6 +198,7 @@ class DeformableDETR(nn.Module): ...@@ -197,6 +198,7 @@ class DeformableDETR(nn.Module):
else: else:
src = self.input_proj[l](srcs[-1]) src = self.input_proj[l](srcs[-1])
b, _, h, w = src.size() b, _, h, w = src.size()
# mask shape (batch_size, h_l, w_l)
mask = F.interpolate(sample_mask, size=src.shape[-2:]).to(torch.bool)[0] mask = F.interpolate(sample_mask, size=src.shape[-2:]).to(torch.bool)[0]
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
srcs.append(src) srcs.append(src)
...@@ -209,7 +211,7 @@ class DeformableDETR(nn.Module): ...@@ -209,7 +211,7 @@ class DeformableDETR(nn.Module):
query_embeds = self.query_embed.weight query_embeds = self.query_embed.weight
# hs shape: (num_layers, batch_size, num_queries, c) # hs shape: (num_layers, batch_size, num_queries, c)
# init_reference shape: (num_queries, 2) # init_reference shape: (batch_size, num_queries, 2)
# inter_references shape: (num_layers, bs, num_queries, num_levels, 2) # inter_references shape: (num_layers, bs, num_queries, num_levels, 2)
( (
hs, hs,
...@@ -222,20 +224,14 @@ class DeformableDETR(nn.Module): ...@@ -222,20 +224,14 @@ class DeformableDETR(nn.Module):
outputs_classes = [] outputs_classes = []
outputs_coords = [] outputs_coords = []
for lvl in range(hs.shape[0]): for lvl in range(hs.shape[0]):
# reference shape: (num_queries, 2)
if lvl == 0: if lvl == 0:
reference = init_reference reference = init_reference
else: else:
reference = inter_references[lvl - 1] reference = inter_references[lvl - 1]
# reference shape: (num_queries, 2)
reference = inverse_sigmoid(reference)
# shape (batch_size, num_queries, num_classes) # shape (batch_size, num_queries, num_classes)
outputs_class = self.class_embed[lvl](hs[lvl]) outputs_class = self.class_embed[lvl](hs[lvl])
# shape (batch_size, num_queries, 4). 4-tuple (cx, cy, w, h) # shape (batch_size, num_queries, 4). 4-tuple (cx, cy, w, h)
assert not torch.any(
torch.isnan(hs[lvl])
), f"lvl {lvl}, NaN hs[lvl] {hs[lvl]}"
tmp = self.bbox_embed[lvl](hs[lvl]) tmp = self.bbox_embed[lvl](hs[lvl])
if reference.shape[-1] == 4: if reference.shape[-1] == 4:
tmp += reference tmp += reference
...@@ -243,15 +239,8 @@ class DeformableDETR(nn.Module): ...@@ -243,15 +239,8 @@ class DeformableDETR(nn.Module):
assert reference.shape[-1] == 2 assert reference.shape[-1] == 2
tmp[..., :2] += reference tmp[..., :2] += reference
# shape (batch_size, num_queries, 4). 4-tuple (cx, cy, w, h) # shape (batch_size, num_queries, 4). 4-tuple (cx, cy, w, h)
assert not torch.any(torch.isnan(tmp)), f"NaN tmp {tmp}"
outputs_coord = tmp.sigmoid() outputs_coord = tmp.sigmoid()
assert not torch.any(
torch.isnan(outputs_coord)
), f"NaN outputs_coord {outputs_coord}"
outputs_classes.append(outputs_class) outputs_classes.append(outputs_class)
outputs_coords.append(outputs_coord) outputs_coords.append(outputs_coord)
# shape (num_levels, batch_size, num_queries, num_classes) # shape (num_levels, batch_size, num_queries, num_classes)
...@@ -332,6 +321,13 @@ class MLP(nn.Module): ...@@ -332,6 +321,13 @@ class MLP(nn.Module):
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
) )
# initialize FC weights and bias
for i, layer in enumerate(self.layers):
if i < num_layers - 1:
nn.init.kaiming_uniform_(layer.weight, a=1)
else:
nn.init.constant_(layer.weight, 0)
nn.init.constant_(layer.bias, 0)
def forward(self, x): def forward(self, x):
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
......
...@@ -16,7 +16,6 @@ from torch import nn ...@@ -16,7 +16,6 @@ from torch import nn
from torch.nn.init import xavier_uniform_, constant_, normal_ from torch.nn.init import xavier_uniform_, constant_, normal_
from ..modules import MSDeformAttn from ..modules import MSDeformAttn
from ..util.misc import inverse_sigmoid
class DeformableTransformer(nn.Module): class DeformableTransformer(nn.Module):
...@@ -259,7 +258,6 @@ class DeformableTransformer(nn.Module): ...@@ -259,7 +258,6 @@ class DeformableTransformer(nn.Module):
output_memory, output_proposals = self.gen_encoder_output_proposals( output_memory, output_proposals = self.gen_encoder_output_proposals(
memory, mask_flatten, spatial_shapes memory, mask_flatten, spatial_shapes
) )
# hack implementation for two-stage Deformable DETR # hack implementation for two-stage Deformable DETR
# shape (bs, K, num_classes) # shape (bs, K, num_classes)
enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers]( enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](
...@@ -270,7 +268,6 @@ class DeformableTransformer(nn.Module): ...@@ -270,7 +268,6 @@ class DeformableTransformer(nn.Module):
self.decoder.bbox_embed[self.decoder.num_layers](output_memory) self.decoder.bbox_embed[self.decoder.num_layers](output_memory)
+ output_proposals + output_proposals
) )
topk = self.two_stage_num_proposals topk = self.two_stage_num_proposals
# topk_proposals: indices of top items. Shape (bs, top_k) # topk_proposals: indices of top items. Shape (bs, top_k)
# TODO (zyan3): use a standalone class_embed layer with 2 output channels? # TODO (zyan3): use a standalone class_embed layer with 2 output channels?
...@@ -280,9 +277,7 @@ class DeformableTransformer(nn.Module): ...@@ -280,9 +277,7 @@ class DeformableTransformer(nn.Module):
enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4) enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
) )
topk_coords_unact = topk_coords_unact.detach() topk_coords_unact = topk_coords_unact.detach()
# reference_points shape (bs, top_k, 4). value \in (0, 1) init_reference_out = reference_points_unact = topk_coords_unact
reference_points = topk_coords_unact.sigmoid()
init_reference_out = reference_points
# shape (bs, top_k, C=512) # shape (bs, top_k, C=512)
pos_trans_out = self.pos_trans_norm( pos_trans_out = self.pos_trans_norm(
self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)) self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))
...@@ -297,16 +292,17 @@ class DeformableTransformer(nn.Module): ...@@ -297,16 +292,17 @@ class DeformableTransformer(nn.Module):
query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1) query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
# tgt shape: (batch_size, num_queries, c) # tgt shape: (batch_size, num_queries, c)
tgt = tgt.unsqueeze(0).expand(bs, -1, -1) tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
# reference_points shape: (batch_size, num_queries, 2), value \in (0, 1) # init_reference_out shape: (batch_size, num_queries, 2), value \in (0, 1)
reference_points = self.reference_points(query_embed).sigmoid() init_reference_out = self.reference_points(query_embed)
init_reference_out = reference_points # block gradient backpropagation here to stabilize optimization
reference_points_unact = init_reference_out.detach()
# decoder # decoder
# hs shape: (num_layers, batch_size, num_queries, c) # hs shape: (num_layers, batch_size, num_queries, c)
# inter_references shape: (num_layers, batch_size, num_queries, num_levels, 2) # inter_references shape: (num_layers, batch_size, num_queries, num_levels, 2)
hs, inter_references = self.decoder( hs, inter_references = self.decoder(
tgt, tgt,
reference_points, reference_points_unact,
memory, memory,
spatial_shapes, spatial_shapes,
level_start_index, level_start_index,
...@@ -562,7 +558,7 @@ class DeformableTransformerDecoder(nn.Module): ...@@ -562,7 +558,7 @@ class DeformableTransformerDecoder(nn.Module):
def forward( def forward(
self, self,
tgt, tgt,
reference_points, reference_points_unact,
src, src,
src_spatial_shapes, src_spatial_shapes,
src_level_start_index, src_level_start_index,
...@@ -573,7 +569,7 @@ class DeformableTransformerDecoder(nn.Module): ...@@ -573,7 +569,7 @@ class DeformableTransformerDecoder(nn.Module):
""" """
Args: Args:
tgt: tensor, shape (batch_size, num_queries, c) tgt: tensor, shape (batch_size, num_queries, c)
reference_points: tensor, shape (batch_size, num_queries, 2 or 4). reference_points_unact: tensor, shape (batch_size, num_queries, 2 or 4).
values \in (0, 1) values \in (0, 1)
src: tensor, shape (batch_size, K, c) where K = \sum_l H_l * w_l src: tensor, shape (batch_size, K, c) where K = \sum_l H_l * w_l
src_spatial_shapes: tensor, shape (num_levels, 2) src_spatial_shapes: tensor, shape (num_levels, 2)
...@@ -587,6 +583,8 @@ class DeformableTransformerDecoder(nn.Module): ...@@ -587,6 +583,8 @@ class DeformableTransformerDecoder(nn.Module):
intermediate = [] intermediate = []
intermediate_reference_points = [] intermediate_reference_points = []
for lid, layer in enumerate(self.layers): for lid, layer in enumerate(self.layers):
reference_points = reference_points_unact.sigmoid()
if reference_points.shape[-1] == 4: if reference_points.shape[-1] == 4:
# shape: (bs, num_queries, 1, 4) * (bs, 1, num_levels, 4) = (bs, num_queries, num_levels, 4) # shape: (bs, num_queries, 1, 4) * (bs, 1, num_levels, 4) = (bs, num_queries, num_levels, 4)
reference_points_input = ( reference_points_input = (
...@@ -609,35 +607,34 @@ class DeformableTransformerDecoder(nn.Module): ...@@ -609,35 +607,34 @@ class DeformableTransformerDecoder(nn.Module):
src_level_start_index, src_level_start_index,
src_padding_mask, src_padding_mask,
) )
assert not torch.any(torch.isnan(output)), f"NaN, lid {lid}, {output}"
# hack implementation for iterative bounding box refinement # hack implementation for iterative bounding box refinement
if self.bbox_embed is not None: if self.bbox_embed is not None:
tmp = self.bbox_embed[lid](output) tmp = self.bbox_embed[lid](output)
if reference_points.shape[-1] == 4: if reference_points.shape[-1] == 4:
new_reference_points = tmp + inverse_sigmoid(reference_points) new_reference_points_unact = tmp + reference_points_unact
new_reference_points = new_reference_points.sigmoid()
else: else:
assert reference_points.shape[-1] == 2 assert reference_points.shape[-1] == 2
new_reference_points = tmp new_reference_points_unact = tmp
new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid( new_reference_points_unact[..., :2] = (
reference_points tmp[..., :2] + reference_points_unact
) )
new_reference_points = new_reference_points.sigmoid()
# block gradient backpropagation here to stabilize optimization # block gradient backpropagation here to stabilize optimization
reference_points = new_reference_points.detach() new_reference_points_unact = new_reference_points_unact.detach()
reference_points_unact = new_reference_points_unact
else:
new_reference_points_unact = reference_points_unact
if self.return_intermediate: if self.return_intermediate:
intermediate.append(output) intermediate.append(output)
intermediate_reference_points.append(reference_points) intermediate_reference_points.append(new_reference_points_unact)
if self.return_intermediate: if self.return_intermediate:
# shape 1: (num_layers, batch_size, num_queries, c) # shape 1: (num_layers, batch_size, num_queries, c)
# shape 2: (num_layers, bs, num_queries, num_levels, 2) # shape 2: (num_layers, bs, num_queries, num_levels, 2)
return torch.stack(intermediate), torch.stack(intermediate_reference_points) return torch.stack(intermediate), torch.stack(intermediate_reference_points)
# output shape: (batch_size, num_queries, c) # output shape: (batch_size, num_queries, c)
# reference_points shape: (bs, num_queries, num_levels, 2) # new_reference_points_unact shape: (bs, num_queries, num_levels, 2)
return output, reference_points return output, new_reference_points_unact
def _get_clones(module, N): def _get_clones(module, N):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment