Commit 9200cbe8 authored by Hang Zhang, committed by Facebook GitHub Bot

Enable Learnable Query TGT

Summary: The learnable query tgt doesn't improve the results on its own, but it helps DETR with reference points in D33420993.

Reviewed By: XiaoliangDai

Differential Revision: D33401417

fbshipit-source-id: 5296f2f969c04df18df292d61a7cf57107bc9b74
parent 4985ef73
@@ -37,6 +37,7 @@ def add_detr_config(cfg):
cfg.MODEL.DETR.WITH_BOX_REFINE = False
cfg.MODEL.DETR.TWO_STAGE = False
cfg.MODEL.DETR.DECODER_BLOCK_GRAD = True
# TRANSFORMER
cfg.MODEL.DETR.NHEADS = 8
cfg.MODEL.DETR.DROPOUT = 0.1
@@ -49,5 +50,9 @@ def add_detr_config(cfg):
cfg.MODEL.DETR.HIDDEN_DIM = 256
cfg.MODEL.DETR.NUM_OBJECT_QUERIES = 100
# solver
cfg.SOLVER.OPTIMIZER = "ADAMW"
cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
# tgt & embeddings
cfg.MODEL.DETR.LEARNABLE_TGT = False
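
Note: a minimal usage sketch for the new flag, assuming a detectron2-style CfgNode and that add_detr_config is the function patched in the hunk above (its import path is omitted here); this snippet is illustrative, not part of the diff.

# Hedged sketch: enable the learnable decoder tgt introduced by this commit.
# Assumes `add_detr_config` (from this diff) is in scope and cfg is a detectron2 CfgNode.
from detectron2.config import get_cfg

cfg = get_cfg()
add_detr_config(cfg)                     # registers MODEL.DETR.* keys, including LEARNABLE_TGT
cfg.MODEL.DETR.LEARNABLE_TGT = True      # opt in; the default above stays False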
@@ -45,6 +45,7 @@ class DETR(nn.Module):
num_queries,
aux_loss=False,
use_focal_loss=False,
query_embed=None,
):
"""Initializes the model.
Parameters:
@@ -63,7 +64,11 @@ class DETR(nn.Module):
hidden_dim, num_classes if use_focal_loss else num_classes + 1
)
self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
self.query_embed = nn.Embedding(num_queries, hidden_dim)
self.query_embed = (
query_embed
if query_embed is not None
else nn.Embedding(num_queries, hidden_dim)
)
self.input_proj = nn.Conv2d(
backbone.num_channels[-1], hidden_dim, kernel_size=1
)
@@ -99,13 +104,19 @@ class DETR(nn.Module):
num_decoder_layers=dec_layers,
normalize_before=pre_norm,
return_intermediate_dec=deep_supervision,
learnable_tgt=cfg.MODEL.DETR.LEARNABLE_TGT,
)
if cfg.MODEL.DETR.LEARNABLE_TGT:
query_embed = nn.Embedding(num_queries, hidden_dim * 2)
else:
query_embed = nn.Embedding(num_queries, hidden_dim)
return {
"backbone": backbone,
"transformer": transformer,
"num_classes": num_classes,
"num_queries": num_queries,
"query_embed": query_embed,
"aux_loss": deep_supervision,
"use_focal_loss": use_focal_loss,
}
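
Note: a short sketch of why the embedding width doubles when LEARNABLE_TGT is on; a single table holds the positional query in its first hidden_dim columns and the learnable tgt in the rest, which the transformer splits at forward time. Sizes below are the config defaults, used only for illustration.

import torch
import torch.nn as nn

num_queries, hidden_dim = 100, 256            # defaults from the config hunk above
query_embed = nn.Embedding(num_queries, hidden_dim * 2)
# First half of each row -> positional query, second half -> learnable tgt.
pos_part, tgt_part = torch.split(query_embed.weight, hidden_dim, dim=1)
assert pos_part.shape == tgt_part.shape == (num_queries, hidden_dim)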
@@ -29,6 +29,7 @@ class Transformer(nn.Module):
activation="relu",
normalize_before=False,
return_intermediate_dec=False,
learnable_tgt=False,
):
super().__init__()
@@ -55,6 +56,7 @@ class Transformer(nn.Module):
self.d_model = d_model
self.nhead = nhead
self.learnable_tgt = learnable_tgt
def _reset_parameters(self):
for p in self.parameters():
@@ -71,10 +73,15 @@ class Transformer(nn.Module):
bs, c, h, w = src.shape
src = src.flatten(2).permute(2, 0, 1) # shape (L, B, C)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1) # shape (L, B, C)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # shape (M, B, C)
mask = mask.flatten(1) # shape (B, HxW)
tgt = torch.zeros_like(query_embed)
if self.learnable_tgt:
query_embed, tgt = torch.split(query_embed, c, dim=1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # shape (M, B, C)
tgt = tgt.unsqueeze(1).repeat(1, bs, 1) # shape (M, B, C)
else:
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # shape (M, B, C)
tgt = torch.zeros_like(query_embed)
# memory shape (L, B, C)
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
# hs shape (NUM_LEVEL, S, B, C)
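
Note: a self-contained sketch of the branch above, showing how the split and the per-batch repeat produce matching (M, B, C) tensors for both the positional query and the tgt; the tensor values are random placeholders, not the real embedding weights.

import torch

bs, c, num_queries = 2, 256, 100               # illustrative sizes only
learnable_tgt = True

if learnable_tgt:
    weight = torch.randn(num_queries, 2 * c)   # stands in for query_embed of width 2*C
    query_embed, tgt = torch.split(weight, c, dim=1)
else:
    query_embed = torch.randn(num_queries, c)
    tgt = torch.zeros_like(query_embed)

query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)    # shape (M, B, C)
tgt = tgt.unsqueeze(1).repeat(1, bs, 1)                     # shape (M, B, C)
assert query_embed.shape == tgt.shape == (num_queries, bs, c)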