"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "3042c63a9521d7d1c23a3127f78878f7288e18a5"
Unverified Commit 8021c684 authored by Yupeng Jia's avatar Yupeng Jia Committed by GitHub
Browse files

Fix some bugs for two stage training of deformable detr (#25045)



* Update modeling_deformable_detr.py

Fix bugs for two stage training

* Update modeling_deformable_detr.py

* Add test_two_stage_training to DeformableDetrModelTest

---------
Co-authored-by: default avataryupeng.jia <yupeng.jia@momenta.ai>
parent 1b354097
......@@ -1558,7 +1558,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
def get_proposal_pos_embed(self, proposals):
"""Get the position embedding of the proposals."""
num_pos_feats = 128
num_pos_feats = self.config.d_model // 2
temperature = 10000
scale = 2 * math.pi
......@@ -1977,12 +1977,11 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
outputs_coord = outputs_coord_logits.sigmoid()
outputs_classes.append(outputs_class)
outputs_coords.append(outputs_coord)
# Keep batch_size as first dimension
outputs_class = torch.stack(outputs_classes, dim=1)
outputs_coord = torch.stack(outputs_coords, dim=1)
outputs_class = torch.stack(outputs_classes)
outputs_coord = torch.stack(outputs_coords)
logits = outputs_class[:, -1]
pred_boxes = outputs_coord[:, -1]
logits = outputs_class[-1]
pred_boxes = outputs_coord[-1]
loss, loss_dict, auxiliary_outputs = None, None, None
if labels is not None:
......@@ -2008,7 +2007,7 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
if self.config.two_stage:
enc_outputs_coord = outputs.enc_outputs_coord_logits.sigmoid()
outputs["enc_outputs"] = {"pred_logits": outputs.enc_outputs_class, "pred_boxes": enc_outputs_coord}
outputs_loss["enc_outputs"] = {"logits": outputs.enc_outputs_class, "pred_boxes": enc_outputs_coord}
loss_dict = criterion(outputs_loss, labels)
# Fourth: compute total loss, as a weighted sum of the various losses
......@@ -2240,7 +2239,7 @@ class DeformableDetrLoss(nn.Module):
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
losses applied, see each loss' doc.
"""
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs" and k != "enc_outputs"}
# Retrieve the matching between the outputs of the last layer and the targets
indices = self.matcher(outputs_without_aux, targets)
......@@ -2272,14 +2271,10 @@ class DeformableDetrLoss(nn.Module):
enc_outputs = outputs["enc_outputs"]
bin_targets = copy.deepcopy(targets)
for bt in bin_targets:
bt["labels"] = torch.zeros_like(bt["labels"])
bt["class_labels"] = torch.zeros_like(bt["class_labels"])
indices = self.matcher(enc_outputs, bin_targets)
for loss in self.losses:
kwargs = {}
if loss == "labels":
# Logging is enabled only for the last layer
kwargs["log"] = False
l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes)
l_dict = {k + "_enc": v for k, v in l_dict.items()}
losses.update(l_dict)
......
......@@ -1463,7 +1463,7 @@ class DetaModel(DetaPreTrainedModel):
def get_proposal_pos_embed(self, proposals):
"""Get the position embedding of the proposals."""
num_pos_feats = 128
num_pos_feats = self.config.d_model // 2
temperature = 10000
scale = 2 * math.pi
......
......@@ -544,6 +544,21 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
def test_two_stage_training(self):
model_class = DeformableDetrForObjectDetection
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
config.two_stage = True
config.auxiliary_loss = True
config.with_box_refine = True
model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
TOLERANCE = 1e-4
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment