Unverified commit ac224dee, authored by Matt and committed by GitHub
Browse files

TF SAM shape flexibility fixes (#23842)

SAM shape flexibility fixes for compilation
parent af45ec0a
...@@ -226,7 +226,8 @@ class TFSamAttention(tf.keras.layers.Layer): ...@@ -226,7 +226,8 @@ class TFSamAttention(tf.keras.layers.Layer):
batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states) batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states)
hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3]) hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3])
return tf.reshape( return tf.reshape(
hidden_states, (batch // max(1, point_batch_size), point_batch_size, n_tokens, n_heads * c_per_head) hidden_states,
(batch // tf.reduce_max([1, point_batch_size]), point_batch_size, n_tokens, n_heads * c_per_head),
) )
def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor: def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor:
...@@ -509,7 +510,7 @@ class TFSamMaskDecoder(tf.keras.layers.Layer): ...@@ -509,7 +510,7 @@ class TFSamMaskDecoder(tf.keras.layers.Layer):
# Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only # Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only
# happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced # happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced
# it with an explicit shape check to avoid data-dependent control flow which breaks XLA. # it with an explicit shape check to avoid data-dependent control flow which breaks XLA.
if sparse_prompt_embeddings.shape[1] != 0: if shape_list(sparse_prompt_embeddings)[1] != 0:
tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2) tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2)
else: else:
tokens = output_tokens tokens = output_tokens
...@@ -695,8 +696,8 @@ class TFSamPromptEncoder(tf.keras.layers.Layer): ...@@ -695,8 +696,8 @@ class TFSamPromptEncoder(tf.keras.layers.Layer):
"""Embeds point prompts.""" """Embeds point prompts."""
points = points + 0.5 # Shift to center of pixel points = points + 0.5 # Shift to center of pixel
if pad: if pad:
target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1]) target_point_shape = (shape_list(points)[0], shape_list(points)[1], 1, shape_list(points)[-1])
target_labels_shape = (points.shape[0], points.shape[1], 1) target_labels_shape = (shape_list(points)[0], shape_list(points)[1], 1)
padding_point = tf.zeros(target_point_shape, dtype=points.dtype) padding_point = tf.zeros(target_point_shape, dtype=points.dtype)
padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype) padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype)
points = tf.concat([points, padding_point], axis=2) points = tf.concat([points, padding_point], axis=2)
...@@ -722,12 +723,12 @@ class TFSamPromptEncoder(tf.keras.layers.Layer): ...@@ -722,12 +723,12 @@ class TFSamPromptEncoder(tf.keras.layers.Layer):
def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor: def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor:
"""Embeds box prompts.""" """Embeds box prompts."""
boxes = boxes + 0.5 # Shift to center of pixel boxes = boxes + 0.5 # Shift to center of pixel
batch_size, nb_boxes = boxes.shape[:2] batch_size, nb_boxes = shape_list(boxes)[:2]
coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2)) coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2))
input_shape = (self.input_image_size, self.input_image_size) input_shape = (self.input_image_size, self.input_image_size)
corner_embedding = self.shared_embedding(coords, input_shape) corner_embedding = self.shared_embedding(coords, input_shape)
corner_embedding += tf.where( corner_embedding += tf.where(
tf.range(corner_embedding.shape[2])[None, None, :, None] == 0, tf.range(shape_list(corner_embedding)[2])[None, None, :, None] == 0,
self.point_embed[2][0], self.point_embed[2][0],
self.point_embed[3][0], self.point_embed[3][0],
) )
...@@ -754,7 +755,7 @@ class TFSamPromptEncoder(tf.keras.layers.Layer): ...@@ -754,7 +755,7 @@ class TFSamPromptEncoder(tf.keras.layers.Layer):
""" """
sparse_embeddings = None sparse_embeddings = None
if input_points is not None: if input_points is not None:
batch_size, point_batch_size = input_points.shape[:2] batch_size, point_batch_size = shape_list(input_points)[:2]
if input_labels is None: if input_labels is None:
raise ValueError("If points are provided, labels must also be provided.") raise ValueError("If points are provided, labels must also be provided.")
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
...@@ -763,7 +764,7 @@ class TFSamPromptEncoder(tf.keras.layers.Layer): ...@@ -763,7 +764,7 @@ class TFSamPromptEncoder(tf.keras.layers.Layer):
) )
sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2) sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2)
if input_boxes is not None: if input_boxes is not None:
batch_size = input_boxes.shape[0] batch_size = shape_list(input_boxes)[0]
box_embeddings = self._embed_boxes(input_boxes) box_embeddings = self._embed_boxes(input_boxes)
if sparse_embeddings is None: if sparse_embeddings is None:
sparse_embeddings = box_embeddings sparse_embeddings = box_embeddings
...@@ -1376,8 +1377,8 @@ class TFSamModel(TFSamPreTrainedModel): ...@@ -1376,8 +1377,8 @@ class TFSamModel(TFSamPreTrainedModel):
" got {}.".format(input_boxes.shape), " got {}.".format(input_boxes.shape),
) )
if input_points is not None and input_boxes is not None: if input_points is not None and input_boxes is not None:
point_batch_size = input_points.shape[1] point_batch_size = shape_list(input_points)[1]
box_batch_size = input_boxes.shape[1] box_batch_size = shape_list(input_boxes)[1]
if point_batch_size != box_batch_size: if point_batch_size != box_batch_size:
raise ValueError( raise ValueError(
"You should provide as many bounding boxes as input points per box. Got {} and {}.".format( "You should provide as many bounding boxes as input points per box. Got {} and {}.".format(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment