Commit e69e0ffe authored by Zhicheng Yan, committed by Facebook GitHub Bot

revert D29048363

Summary:
In D29048363 (https://github.com/facebookresearch/d2go/commit/c480d4e4e213a850cced7758f7b62c20caad8820) we moved the detaching of `reference_points` earlier, in the hope of allowing more gradient flow to update the weights in `self.bbox_embed`.
In this diff, we revert those changes because i) they do not improve box AP, and ii) they may cause unstable optimization when iterative box refinement is turned on.
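For context, a minimal sketch (not the actual d2go code) of the two detach placements in one iterative box refinement step. `bbox_embed`, `inverse_sigmoid`, and `reference_points` come from the diff below; the function names and the standalone `inverse_sigmoid` helper are illustrative.

import torch

def inverse_sigmoid(x, eps=1e-5):
    # Illustrative helper mirroring the logit transform used in Deformable DETR.
    x = x.clamp(min=eps, max=1 - eps)
    return torch.log(x / (1 - x))

def refine_step_detach_early(output, reference_points, bbox_embed):
    # D29048363 (reverted here): detach before adding the box delta, so the
    # refined reference points passed to the next layer still carry gradient
    # through bbox_embed's output.
    delta = bbox_embed(output)
    ref_unact = inverse_sigmoid(reference_points).detach()
    return (delta + ref_unact).sigmoid()

def refine_step_detach_late(output, reference_points, bbox_embed):
    # After this revert: add the delta to the (non-detached) unactivated
    # reference points, apply sigmoid, then detach the refined points so the
    # next decoder layer starts from a constant reference (the pre-D29048363
    # behavior restored by this diff).
    delta = bbox_embed(output)
    new_ref = (delta + inverse_sigmoid(reference_points)).sigmoid()
    return new_ref.detach()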

Reviewed By: zhanghang1989

Differential Revision: D29530735

fbshipit-source-id: 3217c863343836e129d53e07c0eedb2db8164fe6
parent ff9d5d38
@@ -8,8 +8,6 @@
# ------------------------------------------------------------------------
import copy
# import logging
import math
import torch
@@ -125,9 +123,9 @@ class DeformableTransformer(nn.Module):
spatial_shapes: shape (num_levels, 2)
"""
N_, S_, C_ = memory.shape
base_scale = 4.0
proposals = []
_cur = 0
base_object_scale = 0.05
for lvl, (H_, W_) in enumerate(spatial_shapes):
# shape (bs, H_l * W_l)
mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(
@@ -155,7 +153,7 @@ class DeformableTransformer(nn.Module):
# grid shape (bs, H_l, W_l, 2). Value could be > 1
grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
# wh shape (bs, H_l, W_l, 2)
wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
wh = torch.ones_like(grid) * base_object_scale * (2.0 ** lvl)
# proposal shape (bs, H_l * W_l, 4)
proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
proposals.append(proposal)
@@ -174,8 +172,9 @@ class DeformableTransformer(nn.Module):
output_proposals = output_proposals.masked_fill(
~output_proposals_valid, float("inf")
)
# memory: shape (bs, K, C)
output_memory = memory
# memory_padding_mask: shape (bs, K)
output_memory = output_memory.masked_fill(
memory_padding_mask.unsqueeze(-1), float(0)
)
@@ -238,7 +237,7 @@ class DeformableTransformer(nn.Module):
level_start_index = torch.cat(
(spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
)
# spatial_shapes shape: (bs, num_levels, 2)
# valid_ratios shape: (bs, num_levels, 2)
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
# encoder
@@ -255,7 +254,7 @@ class DeformableTransformer(nn.Module):
# prepare input for decoder
bs, _, c = memory.shape
if self.two_stage:
# output_memory shape (bs, K, C). Value = 0
# output_memory shape (bs, K, C)
# output_proposals shape (bs, K, 4)
output_memory, output_proposals = self.gen_encoder_output_proposals(
memory, mask_flatten, spatial_shapes
@@ -274,6 +273,7 @@ class DeformableTransformer(nn.Module):
topk = self.two_stage_num_proposals
# topk_proposals: indices of top items. Shape (bs, top_k)
# TODO (zyan3): use a standalone class_embed layer with 2 output channels?
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
# topk_coords_unact shape (bs, top_k, 4)
topk_coords_unact = torch.gather(
@@ -433,9 +433,6 @@ class DeformableTransformerEncoder(nn.Module):
# shape (N, K, 1, 2) * (N, 1, num_levels, 2) = (N, K, num_levels, 2)
# value should be <1
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
assert not torch.any(
torch.isnan(reference_points)
), f"nan in reference_points {reference_points}"
return reference_points
def forward(
@@ -519,7 +516,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
Args:
tgt: tensor, shape (batch_size, num_queries, c)
query_pos: tensor, shape: (batch_size, num_queries, c)
reference_points: tensor, shape: (batch_size, num_queries, num_levels, 2). values \in (0, 1)
reference_points: tensor, shape: (batch_size, num_queries, num_levels, 2/4). values \in (0, 1)
src: tensor, shape (batch_size, K, c) where K = \sum_l H_l * w_l
src_spatial_shapes: tensor, shape (num_levels, 2)
level_start_index: tensor, shape (num_levels,)
@@ -617,23 +614,18 @@ class DeformableTransformerDecoder(nn.Module):
# hack implementation for iterative bounding box refinement
if self.bbox_embed is not None:
tmp = self.bbox_embed[lid](output)
reference_points_unact = inverse_sigmoid(reference_points)
# block gradient backpropagation here to avoid instable optimization
reference_points_unact_detach = reference_points_unact.detach()
if reference_points.shape[-1] == 4:
new_reference_points = tmp + reference_points_unact_detach
assert not torch.any(
torch.isnan(new_reference_points)
), f"NaN, reference_points {reference_points}, new_reference_points {new_reference_points}"
new_reference_points = tmp + inverse_sigmoid(reference_points)
new_reference_points = new_reference_points.sigmoid()
else:
assert reference_points.shape[-1] == 2
new_reference_points = tmp
new_reference_points[..., :2] = tmp[..., :2] + reference_points_unact_detach
reference_points = new_reference_points.sigmoid()
new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(
reference_points
)
new_reference_points = new_reference_points.sigmoid()
# block gradient backpropagation here to stabilize optimization
reference_points = new_reference_points.detach()
if self.return_intermediate:
intermediate.append(output)