Unverified Commit 1dbc1440 authored by Xiaoke Huang's avatar Xiaoke Huang Committed by GitHub
Browse files

Fix: repeat per sample for SAM image embeddings (#25074)

Repeat per sample for SAM image embeddings
parent cb8abee5
...@@ -507,8 +507,8 @@ class SamMaskDecoder(nn.Module): ...@@ -507,8 +507,8 @@ class SamMaskDecoder(nn.Module):
# Expand per-image data in batch direction to be per-point # Expand per-image data in batch direction to be per-point
image_embeddings = image_embeddings + dense_prompt_embeddings image_embeddings = image_embeddings + dense_prompt_embeddings
image_embeddings = image_embeddings.repeat(point_batch_size, 1, 1, 1) image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
image_positional_embeddings = image_positional_embeddings.repeat(point_batch_size, 1, 1, 1) image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
# Run the transformer, image_positional_embedding are consumed # Run the transformer, image_positional_embedding are consumed
point_embedding, image_embeddings, attentions = self.transformer( point_embedding, image_embeddings, attentions = self.transformer(
......
...@@ -517,8 +517,8 @@ class TFSamMaskDecoder(tf.keras.layers.Layer): ...@@ -517,8 +517,8 @@ class TFSamMaskDecoder(tf.keras.layers.Layer):
point_embeddings = tf.cast(tokens, self.iou_token.dtype) point_embeddings = tf.cast(tokens, self.iou_token.dtype)
image_embeddings = image_embeddings + dense_prompt_embeddings image_embeddings = image_embeddings + dense_prompt_embeddings
image_embeddings = tf.tile(image_embeddings, [point_batch_size, 1, 1, 1]) image_embeddings = tf.repeat(image_embeddings, point_batch_size, axis=0)
image_positional_embeddings = tf.tile(image_positional_embeddings, [point_batch_size, 1, 1, 1]) image_positional_embeddings = tf.repeat(image_positional_embeddings, point_batch_size, axis=0)
point_embedding, image_embeddings, attentions = self.transformer( point_embedding, image_embeddings, attentions = self.transformer(
point_embeddings=point_embeddings, point_embeddings=point_embeddings,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment