"tools/python/vscode:/vscode.git/clone" did not exist on "b6d2329c5e846836d8b42f6ed789918eb88737fb"
Commit 9200cbe8 authored by Hang Zhang, committed by Facebook GitHub Bot

Enable Learnable Query TGT

Summary: A learnable query does not improve results on its own, but it helps DETR with reference points in D33420993

Reviewed By: XiaoliangDai

Differential Revision: D33401417

fbshipit-source-id: 5296f2f969c04df18df292d61a7cf57107bc9b74
parent 4985ef73
@@ -37,6 +37,7 @@ def add_detr_config(cfg):
     cfg.MODEL.DETR.WITH_BOX_REFINE = False
     cfg.MODEL.DETR.TWO_STAGE = False
     cfg.MODEL.DETR.DECODER_BLOCK_GRAD = True
     # TRANSFORMER
     cfg.MODEL.DETR.NHEADS = 8
     cfg.MODEL.DETR.DROPOUT = 0.1
@@ -49,5 +50,9 @@ def add_detr_config(cfg):
     cfg.MODEL.DETR.HIDDEN_DIM = 256
     cfg.MODEL.DETR.NUM_OBJECT_QUERIES = 100
+    # solver
     cfg.SOLVER.OPTIMIZER = "ADAMW"
     cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
+    # tgt & embeddings
+    cfg.MODEL.DETR.LEARNABLE_TGT = False
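
For context, a minimal sketch of enabling the new flag in a detectron2-style setup; the import path of add_detr_config depends on this repo and is assumed here, and the sizes mirror the defaults above:

from detectron2.config import get_cfg

cfg = get_cfg()
add_detr_config(cfg)                     # assumed importable from the config module patched above
cfg.MODEL.DETR.LEARNABLE_TGT = True      # learn the decoder target instead of starting from zeros
cfg.MODEL.DETR.NUM_OBJECT_QUERIES = 100
cfg.MODEL.DETR.HIDDEN_DIM = 256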
@@ -45,6 +45,7 @@ class DETR(nn.Module):
         num_queries,
         aux_loss=False,
         use_focal_loss=False,
+        query_embed=None,
     ):
         """Initializes the model.
         Parameters:
@@ -63,7 +64,11 @@ class DETR(nn.Module):
             hidden_dim, num_classes if use_focal_loss else num_classes + 1
         )
         self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
-        self.query_embed = nn.Embedding(num_queries, hidden_dim)
+        self.query_embed = (
+            query_embed
+            if query_embed is not None
+            else nn.Embedding(num_queries, hidden_dim)
+        )
         self.input_proj = nn.Conv2d(
             backbone.num_channels[-1], hidden_dim, kernel_size=1
         )
@@ -99,13 +104,19 @@ class DETR(nn.Module):
             num_decoder_layers=dec_layers,
             normalize_before=pre_norm,
             return_intermediate_dec=deep_supervision,
+            learnable_tgt=cfg.MODEL.DETR.LEARNABLE_TGT,
         )
+        if cfg.MODEL.DETR.LEARNABLE_TGT:
+            query_embed = nn.Embedding(num_queries, hidden_dim * 2)
+        else:
+            query_embed = nn.Embedding(num_queries, hidden_dim)
         return {
             "backbone": backbone,
             "transformer": transformer,
             "num_classes": num_classes,
             "num_queries": num_queries,
+            "query_embed": query_embed,
             "aux_loss": deep_supervision,
             "use_focal_loss": use_focal_loss,
         }
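
The doubled width is not arbitrary: with LEARNABLE_TGT on, a single embedding table packs both the query positional embedding and the initial decoder target, which Transformer.forward later splits channel-wise (see the next file). A rough sketch using the default sizes from the config above:

import torch.nn as nn

num_queries, hidden_dim = 100, 256
learnable_tgt = True
# first hidden_dim channels: query positional embedding; last hidden_dim channels: learnable tgt
query_embed = nn.Embedding(num_queries, hidden_dim * 2 if learnable_tgt else hidden_dim)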
@@ -29,6 +29,7 @@ class Transformer(nn.Module):
         activation="relu",
         normalize_before=False,
         return_intermediate_dec=False,
+        learnable_tgt=False,
     ):
         super().__init__()
@@ -55,6 +56,7 @@ class Transformer(nn.Module):
         self.d_model = d_model
         self.nhead = nhead
+        self.learnable_tgt = learnable_tgt

     def _reset_parameters(self):
         for p in self.parameters():
@@ -71,10 +73,15 @@ class Transformer(nn.Module):
         bs, c, h, w = src.shape
         src = src.flatten(2).permute(2, 0, 1)  # shape (L, B, C)
         pos_embed = pos_embed.flatten(2).permute(2, 0, 1)  # shape (L, B, C)
-        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)  # shape (M, B, C)
         mask = mask.flatten(1)  # shape (B, HxW)
-        tgt = torch.zeros_like(query_embed)
+        if self.learnable_tgt:
+            query_embed, tgt = torch.split(query_embed, c, dim=1)
+            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)  # shape (M, B, C)
+            tgt = tgt.unsqueeze(1).repeat(1, bs, 1)  # shape (M, B, C)
+        else:
+            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)  # shape (M, B, C)
+            tgt = torch.zeros_like(query_embed)
         # memory shape (L, B, C)
         memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
         # hs shape (NUM_LEVEL, S, B, C)
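
To make the shapes in the new branch concrete, here is a small walk-through with an assumed batch size of 2 and the default sizes (c = 256 channels, 100 queries); the tensors are random stand-ins, not real model weights:

import torch

bs, c, num_q = 2, 256, 100
packed = torch.randn(num_q, 2 * c)                         # query_embed passed in when learnable_tgt is True

query_embed, tgt = torch.split(packed, c, dim=1)           # two (100, 256) halves
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)    # (M, B, C) = (100, 2, 256)
tgt = tgt.unsqueeze(1).repeat(1, bs, 1)                    # (100, 2, 256), learned start for the decoder
# with learnable_tgt False, the decoder input would instead start as zeros of the same shape:
tgt_zeros = torch.zeros_like(query_embed)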