Unverified Commit 713eab45 authored by Nicolas Patry, committed by GitHub

🚨 🚨 🚨 [Breaking change] Deformable DETR intermediate representations (#19678)

* [Breaking change] Deformable DETR intermediate representations

- Naturally fixes the `object-detection` pipeline.
- Moves from `[n_decoders, batch_size, ...]` to `[batch_size, n_decoders, ...]`.

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent fd99ce33
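
In practice the change swaps the first two axes of the stacked intermediate outputs, so any downstream code that indexed the decoder-layer axis must now index the second dimension. A minimal sketch with dummy tensors (shapes follow the updated docstrings below; the values are illustrative only):

    import torch

    # Illustrative dimensions; 6 layers / 300 queries / 256 hidden match the
    # default Deformable DETR configuration.
    batch_size, decoder_layers, num_queries, hidden_size = 2, 6, 300, 256

    # New layout: batch_size leads, the decoder-layer axis comes second.
    intermediate_hidden_states = torch.randn(batch_size, decoder_layers, num_queries, hidden_size)

    # Old code selected a decoder layer with intermediate_hidden_states[level];
    # new code indexes the layer axis at dim 1:
    last_layer = intermediate_hidden_states[:, -1]
    assert last_layer.shape == (batch_size, num_queries, hidden_size)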
@@ -135,9 +135,9 @@ class DeformableDetrDecoderOutput(ModelOutput):
     Args:
         last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
             Sequence of hidden-states at the output of the last layer of the model.
-        intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
+        intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
             Stacked intermediate hidden states (output of each layer of the decoder).
-        intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+        intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`):
             Stacked intermediate reference points (reference points of each layer of the decoder).
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
@@ -171,9 +171,9 @@ class DeformableDetrModelOutput(ModelOutput):
             Initial reference points sent through the Transformer decoder.
         last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
             Sequence of hidden-states at the output of the last layer of the decoder of the model.
-        intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`):
+        intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
             Stacked intermediate hidden states (output of each layer of the decoder).
-        intermediate_reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 4)`):
+        intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
             Stacked intermediate reference points (reference points of each layer of the decoder).
         decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
@@ -266,9 +266,9 @@ class DeformableDetrObjectDetectionOutput(ModelOutput):
             Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_heads, 4,
             4)`. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average
             in the self-attention heads.
-        intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`):
+        intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
             Stacked intermediate hidden states (output of each layer of the decoder).
-        intermediate_reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 4)`):
+        intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
             Stacked intermediate reference points (reference points of each layer of the decoder).
         init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
             Initial reference points sent through the Transformer decoder.
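
All three output classes above expose the same pair of stacked tensors, so the new shapes can be checked directly on a forward pass. A sketch, assuming the public `SenseTime/deformable-detr` checkpoint and the test image used elsewhere in this diff:

    import torch
    from PIL import Image
    from transformers import AutoFeatureExtractor, DeformableDetrForObjectDetection

    extractor = AutoFeatureExtractor.from_pretrained("SenseTime/deformable-detr")
    model = DeformableDetrForObjectDetection.from_pretrained("SenseTime/deformable-detr")

    image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    inputs = extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Batch-first after this change:
    print(outputs.intermediate_hidden_states.shape)     # e.g. torch.Size([1, 6, 300, 256])
    print(outputs.intermediate_reference_points.shape)  # e.g. torch.Size([1, 6, 300, 4])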
@@ -1390,8 +1390,9 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
                 if encoder_hidden_states is not None:
                     all_cross_attentions += (layer_outputs[2],)

-        intermediate = torch.stack(intermediate)
-        intermediate_reference_points = torch.stack(intermediate_reference_points)
+        # Keep batch_size as first dimension
+        intermediate = torch.stack(intermediate, dim=1)
+        intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)

         # add hidden states from the last decoder layer
         if output_hidden_states:
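
`torch.stack` inserts a new axis at the position given by `dim`, so stacking the per-layer `(batch_size, num_queries, hidden_size)` tensors at `dim=1` produces the batch-first `(batch_size, decoder_layers, ...)` layout directly, with no transpose needed afterwards. A standalone illustration with dummy shapes:

    import torch

    batch_size, num_queries, hidden_size = 2, 300, 256
    per_layer_outputs = [torch.randn(batch_size, num_queries, hidden_size) for _ in range(6)]

    stacked = torch.stack(per_layer_outputs, dim=1)  # new axis inserted at position 1
    assert stacked.shape == (batch_size, 6, num_queries, hidden_size)

    # The previous behaviour, torch.stack(per_layer_outputs), stacked at dim=0 and
    # gave shape (6, batch_size, num_queries, hidden_size), i.e. layers first.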
@@ -1913,14 +1914,14 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
         outputs_classes = []
         outputs_coords = []

-        for level in range(hidden_states.shape[0]):
+        for level in range(hidden_states.shape[1]):
             if level == 0:
                 reference = init_reference
             else:
-                reference = inter_references[level - 1]
+                reference = inter_references[:, level - 1]
             reference = inverse_sigmoid(reference)
-            outputs_class = self.class_embed[level](hidden_states[level])
-            delta_bbox = self.bbox_embed[level](hidden_states[level])
+            outputs_class = self.class_embed[level](hidden_states[:, level])
+            delta_bbox = self.bbox_embed[level](hidden_states[:, level])
             if reference.shape[-1] == 4:
                 outputs_coord_logits = delta_bbox + reference
             elif reference.shape[-1] == 2:
@@ -1931,11 +1932,12 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
             outputs_coord = outputs_coord_logits.sigmoid()
             outputs_classes.append(outputs_class)
             outputs_coords.append(outputs_coord)
-        outputs_class = torch.stack(outputs_classes)
-        outputs_coord = torch.stack(outputs_coords)
+        # Keep batch_size as first dimension
+        outputs_class = torch.stack(outputs_classes, dim=1)
+        outputs_coord = torch.stack(outputs_coords, dim=1)

-        logits = outputs_class[-1]
-        pred_boxes = outputs_coord[-1]
+        logits = outputs_class[:, -1]
+        pred_boxes = outputs_coord[:, -1]

         loss, loss_dict, auxiliary_outputs = None, None, None
         if labels is not None:
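
With the batch-first stack, every per-layer read inside the prediction loop becomes `tensor[:, level]`, and the final predictions come from `[:, -1]` rather than `[-1]`. A self-contained sketch of the indexing pattern (the linear heads below are hypothetical stand-ins for the model's `class_embed` modules):

    import torch
    import torch.nn as nn

    batch_size, decoder_layers, num_queries, hidden_size, num_labels = 2, 6, 300, 256, 91
    hidden_states = torch.randn(batch_size, decoder_layers, num_queries, hidden_size)

    # Hypothetical per-layer prediction heads.
    class_heads = nn.ModuleList(nn.Linear(hidden_size, num_labels) for _ in range(decoder_layers))

    outputs_classes = []
    for level in range(hidden_states.shape[1]):  # the layer axis is dim 1 now
        outputs_classes.append(class_heads[level](hidden_states[:, level]))
    outputs_class = torch.stack(outputs_classes, dim=1)  # keep batch_size first

    logits = outputs_class[:, -1]  # predictions of the last decoder layer
    assert logits.shape == (batch_size, num_queries, num_labels)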
@@ -2000,8 +2002,8 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
             encoder_hidden_states=outputs.encoder_hidden_states,
             encoder_attentions=outputs.encoder_attentions,
             intermediate_hidden_states=outputs.intermediate_hidden_states,
+            init_reference_points=outputs.init_reference_points,
             intermediate_reference_points=outputs.intermediate_reference_points,
-            init_reference_points=outputs.init_reference_points,
             enc_outputs_class=outputs.enc_outputs_class,
             enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
         )
@@ -44,12 +44,6 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
     model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING

     def get_test_pipeline(self, model, tokenizer, feature_extractor):
-        if model.__class__.__name__ == "DeformableDetrForObjectDetection":
-            self.skipTest(
-                """Deformable DETR requires a custom CUDA kernel.
-                """
-            )
-
         object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)
         return object_detector, ["./tests/fixtures/tests_samples/COCO/000000039769.png"]
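
With the skip removed, Deformable DETR goes through the stock `object-detection` pipeline. A hedged usage sketch, again assuming the `SenseTime/deformable-detr` checkpoint:

    from transformers import pipeline

    detector = pipeline("object-detection", model="SenseTime/deformable-detr")
    predictions = detector("./tests/fixtures/tests_samples/COCO/000000039769.png")
    # Each entry is a dict of the form {"score": ..., "label": ..., "box": {"xmin": ..., ...}}.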