Unverified Commit 99c3d449 authored by Lorenzo Battistela, committed by GitHub

fixing name position_embeddings to object_queries (#24652)



* fixing name position_embeddings to object_queries

* [fix] renaming variable and docstring to object queries

* [fix] comments: position_embedding to object queries

* [feat] changes from make-fix-copies to keep consistency

* Revert "[feat] changes from make-fix-copies to keep consistency"

This reverts commit 56e3e9ede1d32f7aeefba707ddfaf12c9b4b9e7e.

* [tests] fix wrong expected score

* [fix] wrong assignment causing wrong tensor shapes

* [fix] fixing position_embeddings to object queries to keep consistency (make fix copies)

* [fix] make fix copies, renaming position_embeddings to object_queries

* [fix] position_embeddings to object queries, fixes from make fix copies

* [fix] comments from make fix copies

* [fix] adding args validation to keep version support

* [fix] adding args validation to keep version support -conditional detr

* [fix] adding args validation to keep version support - maskformer

* [style] make fixup style fixes

* [feat] adding args checking

* [feat] fix copies and args checking

* make fixup

* make fixup

---------
Co-authored-by: Lorenzobattistela <lorenzobattistela@gmail.com>
parent 39c37fe4
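The diff below applies the rename while keeping backward compatibility: each affected signature gains **kwargs, the deprecated position_embeddings (and key_value_position_embeddings) keyword is popped from kwargs, passing both the old and the new name raises a ValueError, and passing only the old name emits a one-time deprecation warning before being remapped to the new argument. The snippet that follows is a minimal, self-contained sketch of that pattern, not transformers code: the add_object_queries function is a made-up stand-in, and Python's warnings module is used here in place of the library's logger.warning_once.

import warnings
from typing import Optional

import torch


def add_object_queries(tensor: torch.Tensor, object_queries: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
    # Pop the deprecated keyword; anything else left in kwargs is an error.
    position_embeddings = kwargs.pop("position_embeddings", None)
    if kwargs:
        raise ValueError(f"Unexpected arguments {kwargs.keys()}")

    # Old and new names are mutually exclusive.
    if position_embeddings is not None and object_queries is not None:
        raise ValueError("Cannot specify both position_embeddings and object_queries. Please use just object_queries")

    if position_embeddings is not None:
        # The library uses logger.warning_once; warnings.warn keeps this sketch self-contained.
        warnings.warn("position_embeddings is deprecated, use object_queries instead", DeprecationWarning)
        object_queries = position_embeddings

    return tensor if object_queries is None else tensor + object_queries


if __name__ == "__main__":
    x = torch.zeros(2, 4, 8)
    pos = torch.ones(2, 4, 8)
    out_new = add_object_queries(x, object_queries=pos)          # new name: no warning
    out_old = add_object_queries(x, position_embeddings=pos)     # old name: still works, warns once
    assert torch.equal(out_new, out_old)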
@@ -437,34 +437,79 @@ class DetrAttention(nn.Module):
     def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
         return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

-    def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
-        return tensor if position_embeddings is None else tensor + position_embeddings
+    def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor], **kwargs):
+        position_embeddings = kwargs.pop("position_embeddings", None)
+
+        if kwargs:
+            raise ValueError(f"Unexpected arguments {kwargs.keys()}")
+
+        if position_embeddings is not None and object_queries is not None:
+            raise ValueError(
+                "Cannot specify both position_embeddings and object_queries. Please use just object_queries"
+            )
+
+        if position_embeddings is not None:
+            logger.warning_once(
+                "position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
+            )
+            object_queries = position_embeddings
+
+        return tensor if object_queries is None else tensor + object_queries

     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[torch.Tensor] = None,
+        object_queries: Optional[torch.Tensor] = None,
         key_value_states: Optional[torch.Tensor] = None,
-        key_value_position_embeddings: Optional[torch.Tensor] = None,
+        spatial_position_embeddings: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
+        position_embeddings = kwargs.pop("position_embeddings", None)
+        key_value_position_embeddings = kwargs.pop("key_value_position_embeddings", None)
+
+        if kwargs:
+            raise ValueError(f"Unexpected arguments {kwargs.keys()}")
+
+        if position_embeddings is not None and object_queries is not None:
+            raise ValueError(
+                "Cannot specify both position_embeddings and object_queries. Please use just object_queries"
+            )
+
+        if key_value_position_embeddings is not None and spatial_position_embeddings is not None:
+            raise ValueError(
+                "Cannot specify both key_value_position_embeddings and spatial_position_embeddings. Please use just spatial_position_embeddings"
+            )
+
+        if position_embeddings is not None:
+            logger.warning_once(
+                "position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
+            )
+            object_queries = position_embeddings
+
+        if key_value_position_embeddings is not None:
+            logger.warning_once(
+                "key_value_position_embeddings has been deprecated and will be removed in v4.34. Please use spatial_position_embeddings instead"
+            )
+            spatial_position_embeddings = key_value_position_embeddings
+
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
         batch_size, target_len, embed_dim = hidden_states.size()

         # add position embeddings to the hidden states before projecting to queries and keys
-        if position_embeddings is not None:
+        if object_queries is not None:
             hidden_states_original = hidden_states
-            hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
+            hidden_states = self.with_pos_embed(hidden_states, object_queries)

         # add key-value position embeddings to the key value states
-        if key_value_position_embeddings is not None:
+        if spatial_position_embeddings is not None:
             key_value_states_original = key_value_states
-            key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings)
+            key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)

         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
@@ -563,11 +608,12 @@ class DetrDecoderLayer(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[torch.Tensor] = None,
+        object_queries: Optional[torch.Tensor] = None,
         query_position_embeddings: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = False,
+        **kwargs,
     ):
         """
         Args:
@@ -575,8 +621,8 @@ class DetrDecoderLayer(nn.Module):
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                 values.
-            position_embeddings (`torch.FloatTensor`, *optional*):
-                position embeddings that are added to the queries and keys
+            object_queries (`torch.FloatTensor`, *optional*):
+                object_queries that are added to the hidden states
                 in the cross-attention layer.
             query_position_embeddings (`torch.FloatTensor`, *optional*):
                 position embeddings that are added to the queries and keys
@@ -590,12 +636,28 @@ class DetrDecoderLayer(nn.Module):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more detail.
         """
+        position_embeddings = kwargs.pop("position_embeddings", None)
+
+        if kwargs:
+            raise ValueError(f"Unexpected arguments {kwargs.keys()}")
+
+        if position_embeddings is not None and object_queries is not None:
+            raise ValueError(
+                "Cannot specify both position_embeddings and object_queries. Please use just object_queries"
+            )
+
+        if position_embeddings is not None:
+            logger.warning_once(
+                "position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
+            )
+            object_queries = position_embeddings
+
         residual = hidden_states

         # Self Attention
         hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
-            position_embeddings=query_position_embeddings,
+            object_queries=query_position_embeddings,
             attention_mask=attention_mask,
             output_attentions=output_attentions,
         )
@@ -611,10 +673,10 @@ class DetrDecoderLayer(nn.Module):
             hidden_states, cross_attn_weights = self.encoder_attn(
                 hidden_states=hidden_states,
-                position_embeddings=query_position_embeddings,
+                object_queries=query_position_embeddings,
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
-                key_value_position_embeddings=position_embeddings,
+                spatial_position_embeddings=object_queries,
                 output_attentions=output_attentions,
             )
@@ -662,7 +724,7 @@ class DetrDecoder(nn.Module):
     Some small tweaks for DETR:

-    - position_embeddings and query_position_embeddings are added to the forward pass.
+    - object_queries and query_position_embeddings are added to the forward pass.
     - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.

     Args:
@@ -687,11 +749,12 @@ class DetrDecoder(nn.Module):
         attention_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
-        position_embeddings=None,
+        object_queries=None,
         query_position_embeddings=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **kwargs,
     ):
         r"""
         Args:
@@ -715,7 +778,7 @@ class DetrDecoder(nn.Module):
                 - 1 for pixels that are real (i.e. **not masked**),
                 - 0 for pixels that are padding (i.e. **masked**).

-            position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                 Position embeddings that are added to the queries and keys in each cross-attention layer.
             query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
                 , *optional*): Position embeddings that are added to the queries and keys in each self-attention layer.
@@ -728,6 +791,21 @@ class DetrDecoder(nn.Module):
             return_dict (`bool`, *optional*):
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         """
+        position_embeddings = kwargs.pop("position_embeddings", None)
+
+        if kwargs:
+            raise ValueError(f"Unexpected arguments {kwargs.keys()}")
+
+        if position_embeddings is not None and object_queries is not None:
+            raise ValueError(
+                "Cannot specify both position_embeddings and object_queries. Please use just object_queries"
+            )
+
+        if position_embeddings is not None:
+            logger.warning_once(
+                "position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
+            )
+            object_queries = position_embeddings
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -788,7 +866,7 @@ class DetrDecoder(nn.Module):
             layer_outputs = decoder_layer(
                 hidden_states,
                 attention_mask=combined_attention_mask,
-                position_embeddings=position_embeddings,
+                object_queries=object_queries,
                 query_position_embeddings=query_position_embeddings,
                 encoder_hidden_states=encoder_hidden_states,
                 encoder_attention_mask=encoder_attention_mask,
@@ -1438,23 +1516,23 @@ class MaskFormerTransformerModule(nn.Module):
     ) -> DetrDecoderOutput:
         if self.input_projection is not None:
             image_features = self.input_projection(image_features)
-        position_embeddings = self.position_embedder(image_features)
+        object_queries = self.position_embedder(image_features)
         # repeat the queries "q c -> b q c"
         batch_size = image_features.shape[0]
         queries_embeddings = self.queries_embedder.weight.unsqueeze(0).repeat(batch_size, 1, 1)
         inputs_embeds = torch.zeros_like(queries_embeddings, requires_grad=True)
         batch_size, num_channels, height, width = image_features.shape
-        # rearrange both image_features and position_embeddings "b c h w -> b (h w) c"
+        # rearrange both image_features and object_queries "b c h w -> b (h w) c"
         image_features = image_features.view(batch_size, num_channels, height * width).permute(0, 2, 1)
-        position_embeddings = position_embeddings.view(batch_size, num_channels, height * width).permute(0, 2, 1)
+        object_queries = object_queries.view(batch_size, num_channels, height * width).permute(0, 2, 1)

         decoder_output: DetrDecoderOutput = self.decoder(
             inputs_embeds=inputs_embeds,
             attention_mask=None,
             encoder_hidden_states=image_features,
             encoder_attention_mask=None,
-            position_embeddings=position_embeddings,
+            object_queries=object_queries,
             query_position_embeddings=queries_embeddings,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
...
@@ -469,34 +469,79 @@ class TableTransformerAttention(nn.Module):
     def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
         return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

-    def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
-        return tensor if position_embeddings is None else tensor + position_embeddings
+    def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor], **kwargs):
+        position_embeddings = kwargs.pop("position_embeddings", None)
+
+        if kwargs:
+            raise ValueError(f"Unexpected arguments {kwargs.keys()}")
+
+        if position_embeddings is not None and object_queries is not None:
+            raise ValueError(
+                "Cannot specify both position_embeddings and object_queries. Please use just object_queries"
+            )
+
+        if position_embeddings is not None:
+            logger.warning_once(
+                "position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
+            )
+            object_queries = position_embeddings
+
+        return tensor if object_queries is None else tensor + object_queries

     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[torch.Tensor] = None,
+        object_queries: Optional[torch.Tensor] = None,
         key_value_states: Optional[torch.Tensor] = None,
-        key_value_position_embeddings: Optional[torch.Tensor] = None,
+        spatial_position_embeddings: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
+        position_embeddings = kwargs.pop("position_embeddings", None)
+        key_value_position_embeddings = kwargs.pop("key_value_position_embeddings", None)
+
+        if kwargs:
+            raise ValueError(f"Unexpected arguments {kwargs.keys()}")
+
+        if position_embeddings is not None and object_queries is not None:
+            raise ValueError(
+                "Cannot specify both position_embeddings and object_queries. Please use just object_queries"
+            )
+
+        if key_value_position_embeddings is not None and spatial_position_embeddings is not None:
+            raise ValueError(
+                "Cannot specify both key_value_position_embeddings and spatial_position_embeddings. Please use just spatial_position_embeddings"
+            )
+
+        if position_embeddings is not None:
+            logger.warning_once(
+                "position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
+            )
+            object_queries = position_embeddings
+
+        if key_value_position_embeddings is not None:
+            logger.warning_once(
+                "key_value_position_embeddings has been deprecated and will be removed in v4.34. Please use spatial_position_embeddings instead"
+            )
+            spatial_position_embeddings = key_value_position_embeddings
+
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
         batch_size, target_len, embed_dim = hidden_states.size()

         # add position embeddings to the hidden states before projecting to queries and keys
-        if position_embeddings is not None:
+        if object_queries is not None:
             hidden_states_original = hidden_states
-            hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
+            hidden_states = self.with_pos_embed(hidden_states, object_queries)

         # add key-value position embeddings to the key value states
-        if key_value_position_embeddings is not None:
+        if spatial_position_embeddings is not None:
             key_value_states_original = key_value_states
-            key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings)
+            key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)

         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
@@ -587,7 +632,7 @@ class TableTransformerEncoderLayer(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
-        position_embeddings: torch.Tensor = None,
+        object_queries: torch.Tensor = None,
         output_attentions: bool = False,
     ):
         """
@@ -596,7 +641,7 @@ class TableTransformerEncoderLayer(nn.Module):
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                 values.
-            position_embeddings (`torch.FloatTensor`, *optional*): position embeddings, to be added to hidden_states.
+            object_queries (`torch.FloatTensor`, *optional*): object queries, to be added to hidden_states.
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more detail.
@@ -607,7 +652,7 @@ class TableTransformerEncoderLayer(nn.Module):
         hidden_states, attn_weights = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-            position_embeddings=position_embeddings,
+            object_queries=object_queries,
             output_attentions=output_attentions,
         )
@@ -668,7 +713,7 @@ class TableTransformerDecoderLayer(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[torch.Tensor] = None,
+        object_queries: Optional[torch.Tensor] = None,
         query_position_embeddings: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -680,11 +725,11 @@ class TableTransformerDecoderLayer(nn.Module):
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                 values.
-            position_embeddings (`torch.FloatTensor`, *optional*):
-                position embeddings that are added to the queries and keys
+            object_queries (`torch.FloatTensor`, *optional*):
+                object queries that are added to the queries and keys
                 in the cross-attention layer.
             query_position_embeddings (`torch.FloatTensor`, *optional*):
-                position embeddings that are added to the queries and keys
+                object queries that are added to the queries and keys
                 in the self-attention layer.
             encoder_hidden_states (`torch.FloatTensor`):
                 cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -701,7 +746,7 @@ class TableTransformerDecoderLayer(nn.Module):
         # Self Attention
         hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
-            position_embeddings=query_position_embeddings,
+            object_queries=query_position_embeddings,
             attention_mask=attention_mask,
             output_attentions=output_attentions,
         )
@@ -717,10 +762,10 @@ class TableTransformerDecoderLayer(nn.Module):
         if encoder_hidden_states is not None:
             hidden_states, cross_attn_weights = self.encoder_attn(
                 hidden_states=hidden_states,
-                position_embeddings=query_position_embeddings,
+                object_queries=query_position_embeddings,
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
-                key_value_position_embeddings=position_embeddings,
+                spatial_position_embeddings=object_queries,
                 output_attentions=output_attentions,
             )
@@ -854,7 +899,7 @@ class TableTransformerEncoder(TableTransformerPreTrainedModel):
     Small tweak for Table Transformer:

-    - position_embeddings are added to the forward pass.
+    - object_queries are added to the forward pass.

     Args:
         config: TableTransformerConfig
@@ -877,7 +922,7 @@ class TableTransformerEncoder(TableTransformerPreTrainedModel):
         self,
         inputs_embeds=None,
         attention_mask=None,
-        position_embeddings=None,
+        object_queries=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
@@ -895,7 +940,7 @@ class TableTransformerEncoder(TableTransformerPreTrainedModel):
                 [What are attention masks?](../glossary#attention-mask)

-            position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                 Position embeddings that are added to the queries and keys in each self-attention layer.
             output_attentions (`bool`, *optional*):
@@ -936,11 +981,11 @@ class TableTransformerEncoder(TableTransformerPreTrainedModel):
             if to_drop:
                 layer_outputs = (None, None)
             else:
-                # we add position_embeddings as extra input to the encoder_layer
+                # we add object_queries as extra input to the encoder_layer
                 layer_outputs = encoder_layer(
                     hidden_states,
                     attention_mask,
-                    position_embeddings=position_embeddings,
+                    object_queries=object_queries,
                     output_attentions=output_attentions,
                 )
@@ -970,7 +1015,7 @@ class TableTransformerDecoder(TableTransformerPreTrainedModel):
     Some small tweaks for TABLE_TRANSFORMER:

-    - position_embeddings and query_position_embeddings are added to the forward pass.
+    - object_queries and query_position_embeddings are added to the forward pass.
     - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.

     Args:
@@ -996,11 +1041,12 @@ class TableTransformerDecoder(TableTransformerPreTrainedModel):
         attention_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
-        position_embeddings=None,
+        object_queries=None,
         query_position_embeddings=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **kwargs,
     ):
         r"""
         Args:
@@ -1024,10 +1070,11 @@ class TableTransformerDecoder(TableTransformerPreTrainedModel):
                 - 1 for pixels that are real (i.e. **not masked**),
                 - 0 for pixels that are padding (i.e. **masked**).

-            position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-                Position embeddings that are added to the queries and keys in each cross-attention layer.
+            object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Object queries that are added to the queries and keys in each cross-attention layer.
             query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
-                , *optional*): Position embeddings that are added to the queries and keys in each self-attention layer.
+                , *optional*): Position embeddings that are added to the values and keys in each self-attention layer.
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more detail.
@@ -1037,6 +1084,22 @@ class TableTransformerDecoder(TableTransformerPreTrainedModel):
             return_dict (`bool`, *optional*):
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         """
+        position_embeddings = kwargs.pop("position_embeddings", None)
+
+        if kwargs:
+            raise ValueError(f"Unexpected arguments {kwargs.keys()}")
+
+        if position_embeddings is not None and object_queries is not None:
+            raise ValueError(
+                "Cannot specify both position_embeddings and object_queries. Please use just object_queries"
+            )
+
+        if position_embeddings is not None:
+            logger.warning_once(
+                "position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
+            )
+            object_queries = position_embeddings
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1099,7 +1162,7 @@ class TableTransformerDecoder(TableTransformerPreTrainedModel):
             layer_outputs = decoder_layer(
                 hidden_states,
                 attention_mask=combined_attention_mask,
-                position_embeddings=position_embeddings,
+                object_queries=object_queries,
                 query_position_embeddings=query_position_embeddings,
                 encoder_hidden_states=encoder_hidden_states,
                 encoder_attention_mask=encoder_attention_mask,
@@ -1158,8 +1221,8 @@ class TableTransformerModel(TableTransformerPreTrainedModel):
         # Create backbone + positional encoding
         backbone = TableTransformerConvEncoder(config)
-        position_embeddings = build_position_encoding(config)
-        self.backbone = TableTransformerConvModel(backbone, position_embeddings)
+        object_queries = build_position_encoding(config)
+        self.backbone = TableTransformerConvModel(backbone, object_queries)

         # Create projection layer
         self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
@@ -1254,21 +1317,21 @@ class TableTransformerModel(TableTransformerPreTrainedModel):
         # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
         projected_feature_map = self.input_projection(feature_map)

-        # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
+        # Third, flatten the feature map + object queries of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
         # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
         flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
-        position_embeddings = position_embeddings_list[-1].flatten(2).permute(0, 2, 1)
+        object_queries = position_embeddings_list[-1].flatten(2).permute(0, 2, 1)

         flattened_mask = mask.flatten(1)

-        # Fourth, sent flattened_features + flattened_mask + position embeddings through encoder
+        # Fourth, sent flattened_features + flattened_mask + object queries through encoder
         # flattened_features is a Tensor of shape (batch_size, heigth*width, hidden_size)
         # flattened_mask is a Tensor of shape (batch_size, heigth*width)
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
                 inputs_embeds=flattened_features,
                 attention_mask=flattened_mask,
-                position_embeddings=position_embeddings,
+                object_queries=object_queries,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
@@ -1281,7 +1344,7 @@ class TableTransformerModel(TableTransformerPreTrainedModel):
                 attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
             )

-        # Fifth, sent query embeddings + position embeddings through the decoder (which is conditioned on the encoder output)
+        # Fifth, sent query embeddings + object queries through the decoder (which is conditioned on the encoder output)
         query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
         queries = torch.zeros_like(query_position_embeddings)
@@ -1289,7 +1352,7 @@ class TableTransformerModel(TableTransformerPreTrainedModel):
         decoder_outputs = self.decoder(
             inputs_embeds=queries,
             attention_mask=None,
-            position_embeddings=position_embeddings,
+            object_queries=object_queries,
             query_position_embeddings=query_position_embeddings,
             encoder_hidden_states=encoder_outputs[0],
             encoder_attention_mask=flattened_mask,
...
@@ -606,7 +606,7 @@ class DetrModelIntegrationTestsTimmBackbone(unittest.TestCase):
             torch_device
         )
         expected_number_of_segments = 5
-        expected_first_segment = {"id": 1, "label_id": 17, "was_fused": False, "score": 0.994096}
+        expected_first_segment = {"id": 1, "label_id": 17, "was_fused": False, "score": 0.994097}

         number_of_unique_segments = len(torch.unique(results["segmentation"]))
         self.assertTrue(
...