"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "f8b28e461a97162d70b48f44970c580a1dd6df73"
Unverified Commit 527ab894 authored by NielsRogge, committed by GitHub

Add PerSAM [bis] (#23659)

* Add PerSAM args

* Make attn_sim optional

* Rename to attention_similarity

* Add docstrings

* Improve docstrings
parent aa30cd4f
@@ -224,7 +224,7 @@ class SamAttention(nn.Module):
         hidden_states = hidden_states.transpose(1, 2)
         return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
 
-    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
+    def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None) -> Tensor:
         # Input projections
         query = self.q_proj(query)
         key = self.k_proj(key)
@@ -242,6 +242,10 @@ class SamAttention(nn.Module):
         attn = attn / math.sqrt(c_per_head)
         attn = torch.softmax(attn, dim=-1)
 
+        if attention_similarity is not None:
+            attn = attn + attention_similarity
+            attn = torch.softmax(attn, dim=-1)
+
         # Get output
         out = attn @ value
         out = self._recombine_heads(out, point_batch_size)
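The hunk above is the core PerSAM hook: a precomputed similarity map is added to the already-normalized attention weights and the result is renormalized with a second softmax. A minimal standalone sketch of that mechanism (illustrative names only, not the transformers internals):

import torch

def similarity_biased_attention(query, key, value, attention_similarity=None):
    # Standard scaled dot-product attention.
    attn = query @ key.transpose(-2, -1) / key.shape[-1] ** 0.5
    attn = torch.softmax(attn, dim=-1)
    # PerSAM-style target guidance: bias the normalized weights with a
    # similarity map over the key tokens and renormalize, as in the hunk above.
    if attention_similarity is not None:
        attn = attn + attention_similarity
        attn = torch.softmax(attn, dim=-1)
    return attn @ value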
@@ -290,6 +294,7 @@ class SamTwoWayAttentionBlock(nn.Module):
         keys: Tensor,
         query_point_embedding: Tensor,
         key_point_embedding: Tensor,
+        attention_similarity: Tensor,
         output_attentions: bool = False,
     ):
         # Self attention block
@@ -305,7 +310,9 @@ class SamTwoWayAttentionBlock(nn.Module):
         query = queries + query_point_embedding
         key = keys + key_point_embedding
 
-        attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys)
+        attn_out = self.cross_attn_token_to_image(
+            query=query, key=key, value=keys, attention_similarity=attention_similarity
+        )
         queries = queries + attn_out
 
         queries = self.layer_norm2(queries)
@@ -353,6 +360,8 @@ class SamTwoWayTransformer(nn.Module):
         point_embeddings: Tensor,
         image_embeddings: Tensor,
         image_positional_embeddings: Tensor,
+        attention_similarity: Tensor,
+        target_embedding=None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -377,11 +386,15 @@ class SamTwoWayTransformer(nn.Module):
         # Apply transformer blocks and final layernorm
         for layer in self.layers:
+            if target_embedding is not None:
+                queries += target_embedding
+
             queries, keys, attention_outputs = layer(
                 queries=queries,
                 keys=keys,
                 query_point_embedding=point_embeddings,
                 key_point_embedding=image_positional_embeddings,
+                attention_similarity=attention_similarity,
                 output_attentions=output_attentions,
             )
@@ -460,6 +473,8 @@ class SamMaskDecoder(nn.Module):
         dense_prompt_embeddings: torch.Tensor,
         multimask_output: bool,
         output_attentions: Optional[bool] = None,
+        attention_similarity: torch.Tensor = None,
+        target_embedding: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Predict masks given image and prompt embeddings.
@@ -500,6 +515,8 @@ class SamMaskDecoder(nn.Module):
             point_embeddings=point_embeddings,
             image_embeddings=image_embeddings,
             image_positional_embeddings=image_positional_embeddings,
+            attention_similarity=attention_similarity,
+            target_embedding=target_embedding,
             output_attentions=output_attentions,
         )
         iou_token_out = point_embedding[:, :, 0, :]
@@ -576,8 +593,12 @@ class SamMaskEmbedding(nn.Module):
         self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
         self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
         self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
-        self.layer_norm1 = SamLayerNorm(self.mask_input_channels, config.layer_norm_eps)
-        self.layer_norm2 = SamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps)
+        self.layer_norm1 = SamLayerNorm(
+            self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
+        )
+        self.layer_norm2 = SamLayerNorm(
+            self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
+        )
 
     def forward(self, masks):
         hidden_states = self.conv1(masks)
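The layer norm fix above passes data_format="channels_first" because the mask embedding's convolutions produce (batch, channels, height, width) feature maps, so normalization has to run over the channel axis rather than the last axis. A rough standalone equivalent of that behaviour, as a sketch rather than the actual SamLayerNorm implementation:

import torch
from torch import nn

class ChannelsFirstLayerNorm(nn.Module):
    # Normalizes over the channel dimension of a (batch, channels, H, W) tensor.
    def __init__(self, num_channels, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]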
@@ -1146,6 +1167,12 @@ SAM_INPUTS_DOCSTRING = r"""
             In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
             bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
             "best" mask, by specifying `multimask_output=False`.
+        attention_similarity (`torch.FloatTensor`, *optional*):
+            Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
+            model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
+        target_embedding (`torch.FloatTensor`, *optional*):
+            Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
+            the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
         output_attentions (`bool`, *optional*):
             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
             tensors for more detail.
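Taken together, the two new inputs let SamModel run the PerSAM recipe end to end: attention_similarity biases the mask decoder's token-to-image attention and target_embedding is added to the decoder queries. A hedged usage sketch, assuming the two tensors have already been computed from a reference image and mask as described in the PerSAM paper; the zero-valued placeholder shapes below are assumptions chosen only to broadcast inside the decoder, not values the recipe would actually produce:

import torch
import requests
from PIL import Image
from transformers import SamModel, SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
inputs = processor(image, input_points=[[[450, 600]]], return_tensors="pt")

# Placeholders: in PerSAM these come from a reference image/mask.
attention_similarity = torch.zeros(1, 1, 1, 64 * 64)  # bias over the 64x64 image tokens (assumed shape)
target_embedding = torch.zeros(1, 1, 1, 256)          # target-concept embedding (assumed shape)

with torch.no_grad():
    outputs = model(
        **inputs,
        attention_similarity=attention_similarity,
        target_embedding=target_embedding,
        multimask_output=False,
    )
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
)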
@@ -1265,6 +1292,8 @@ class SamModel(SamPreTrainedModel):
         input_masks: Optional[torch.LongTensor] = None,
         image_embeddings: Optional[torch.FloatTensor] = None,
         multimask_output: bool = True,
+        attention_similarity: Optional[torch.FloatTensor] = None,
+        target_embedding: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict=None,
@@ -1374,6 +1403,8 @@ class SamModel(SamPreTrainedModel):
             sparse_prompt_embeddings=sparse_embeddings,
             dense_prompt_embeddings=dense_embeddings,
             multimask_output=multimask_output,
+            attention_similarity=attention_similarity,
+            target_embedding=target_embedding,
             output_attentions=output_attentions,
         )
...