"examples/pytorch/test_pytorch_examples.py" did not exist on "76c4d8bf26de3e4ab23b8afeed68479c2bbd9cbd"
Unverified Commit 357971ec authored by Sangbum Daniel Choi's avatar Sangbum Daniel Choi Committed by GitHub
Browse files

fix auxiliary loss training in DetrSegmentation (#28354)

* fix auxiliary loss training in detrSegmentation

* add auxiliary_loss testing
parent 8604dd30
......@@ -1826,9 +1826,9 @@ class DetrForSegmentation(DetrPreTrainedModel):
outputs_loss["pred_masks"] = pred_masks
if self.config.auxiliary_loss:
intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[-1]
outputs_class = self.class_labels_classifier(intermediate)
outputs_coord = self.bbox_predictor(intermediate).sigmoid()
auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord)
outputs_class = self.detr.class_labels_classifier(intermediate)
outputs_coord = self.detr.bbox_predictor(intermediate).sigmoid()
auxiliary_outputs = self.detr._set_aux_loss(outputs_class, outputs_coord)
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
loss_dict = criterion(outputs_loss, labels)
......
......@@ -399,6 +399,22 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
self.assertIsNotNone(decoder_attentions.grad)
self.assertIsNotNone(cross_attentions.grad)
def test_forward_auxiliary_loss(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.auxiliary_loss = True
# only test for object detection and segmentation model
for model_class in self.all_model_classes[1:]:
model = model_class(config)
model.to(torch_device)
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
outputs = model(**inputs)
self.assertIsNotNone(outputs.auxiliary_outputs)
self.assertEqual(len(outputs.auxiliary_outputs), self.model_tester.num_hidden_layers - 1)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment