Unverified commit 9ade58f0 authored by Michael Benayoun, committed by GitHub

[ONNX] Sam fix (#23110)



* [WIP] Fix for the ONNX export

* Apply changes

* Remove commented code

* Resolve todo

* empty -> zeros

* fix slow tests

---------
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
parent 4baa34c1
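
The hunks below all target the same failure mode: shape-dependent Python branches and in-place masked writes get specialized or dropped when the model is traced for ONNX export. A minimal standalone illustration of that specialization problem (not code from this commit, just a sketch of why such branches are removed):

import torch
import torch.nn.functional as F


def maybe_pad(x):
    # Python-level branch on a shape value: fine in eager mode, but only the branch
    # taken at trace time gets recorded into the exported graph.
    pad = (4 - x.shape[-1] % 4) % 4
    if pad > 0:
        x = F.pad(x, (0, pad))
    return x


traced = torch.jit.trace(maybe_pad, torch.ones(1, 8))  # pad == 0 at trace time
print(traced(torch.ones(1, 6)).shape)  # torch.Size([1, 6]): the pad branch was never recorded
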
@@ -223,9 +223,7 @@ class SamAttention(nn.Module):
     def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
         batch, n_heads, n_tokens, c_per_head = hidden_states.shape
         hidden_states = hidden_states.transpose(1, 2)
-        return hidden_states.reshape(
-            batch // max(1, point_batch_size), point_batch_size, n_tokens, n_heads * c_per_head
-        )
+        return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
 
     def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
         # Input projections
@@ -482,7 +480,7 @@ class SamMaskDecoder(nn.Module):
                 Whether or not to return the attentions tensors of all attention layers.
         """
         batch_size, num_channels, height, width = image_embeddings.shape
-        point_batch_size = max(1, sparse_prompt_embeddings.shape[1])
+        point_batch_size = sparse_prompt_embeddings.shape[1]
         # Concatenate output tokens
         output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
         output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
@@ -634,8 +632,18 @@ class SamPromptEncoder(nn.Module):
             torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
         )
-        point_embedding[labels == 0] += self.point_embed[0].weight
-        point_embedding[labels == 1] += self.point_embed[1].weight
+        point_embedding = torch.where(
+            (labels == 0)[:, :, :, None],
+            point_embedding + self.point_embed[0].weight[None, None, :, :],
+            point_embedding,
+        )
+
+        point_embedding = torch.where(
+            (labels == 1)[:, :, :, None],
+            point_embedding + self.point_embed[1].weight[None, None, :, :],
+            point_embedding,
+        )
 
         return point_embedding
 
     def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
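
As a side note on the hunk above, here is a small standalone check (not part of this diff) that the torch.where rewrite is numerically equivalent to the masked in-place additions it replaces, while avoiding data-dependent in-place writes that tend to trip up the ONNX exporter; the shapes and weights below are made-up stand-ins:

import torch

point_embedding = torch.randn(2, 3, 4, 8)      # (batch, point_batch, n_points, hidden)
labels = torch.randint(-1, 2, (2, 3, 4))        # -1 = padding, 0/1 = point type
w0, w1 = torch.randn(1, 8), torch.randn(1, 8)   # stand-ins for the point_embed weights

# Old style: boolean-mask in-place additions.
expected = point_embedding.clone()
expected[labels == 0] += w0
expected[labels == 1] += w1

# New style: out-of-place torch.where with a broadcast mask.
result = torch.where((labels == 0)[:, :, :, None], point_embedding + w0, point_embedding)
result = torch.where((labels == 1)[:, :, :, None], result + w1, result)

print(torch.allclose(expected, result))  # True
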
@@ -675,8 +683,7 @@ class SamPromptEncoder(nn.Module):
             if input_labels is None:
                 raise ValueError("If points are provided, labels must also be provided.")
             point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
-            sparse_embeddings = torch.empty((batch_size, point_batch_size, 0, self.hidden_size), device=target_device)
-            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=2)
+            sparse_embeddings = point_embeddings
         if input_boxes is not None:
             batch_size = input_boxes.shape[0]
             box_embeddings = self._embed_boxes(input_boxes)
@@ -692,7 +699,7 @@ class SamPromptEncoder(nn.Module):
             )
         if sparse_embeddings is None:
-            sparse_embeddings = torch.empty((batch_size, 0, 1, self.hidden_size), device=target_device)
+            sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)
 
         return sparse_embeddings, dense_embeddings
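
A small shape note on the prompt-encoder hunks above (standalone, not part of this diff): the old "no prompt" placeholder was an empty tensor with a zero-sized point dimension, which is what forced the max(1, point_batch_size) guards removed in the first two hunks; the new placeholder is a real all-zeros embedding with point_batch_size == 1, so the guards are unnecessary and the value is deterministic. The sizes below are illustrative only:

import torch

batch_size, hidden_size = 1, 256  # example values; hidden_size is model-dependent

old = torch.empty((batch_size, 0, 1, hidden_size))   # zero-sized dim: shape[1] == 0 downstream
new = torch.zeros((batch_size, 1, 1, hidden_size))   # real placeholder: shape[1] == 1

print(old.shape[1], old.numel())  # 0 0   -> needed the max(1, ...) guard
print(new.shape[1], new.numel())  # 1 256 -> guard no longer needed
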
@@ -742,17 +749,13 @@ class SamVisionAttention(nn.Module):
             Extracted positional embeddings according to relative positions.
         """
         max_rel_dist = int(2 * max(q_size, k_size) - 1)
-        # Interpolate rel pos if needed.
-        if rel_pos.shape[0] != max_rel_dist:
-            # Interpolate rel pos.
-            rel_pos_resized = F.interpolate(
-                rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
-                size=max_rel_dist,
-                mode="linear",
-            )
-            rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
-        else:
-            rel_pos_resized = rel_pos
+        # Interpolate rel pos.
+        rel_pos_resized = F.interpolate(
+            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+            size=max_rel_dist,
+            mode="linear",
+        )
+        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
 
         # Scale the coords with short length if shapes for q and k are different.
         q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
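
A standalone sanity check on the hunk above (not from this diff): linear interpolation back to the original length is numerically the identity, so interpolating unconditionally matches the old "skip when shapes already agree" branch while keeping a shape-dependent if out of the traced graph. The sizes are arbitrary:

import torch
import torch.nn.functional as F

rel_pos = torch.randn(1, 64, 127)  # (1, head_dim, 2 * size - 1), channels-first for interpolate
resized = F.interpolate(rel_pos, size=127, mode="linear")

print(torch.allclose(rel_pos, resized))  # True: same-length linear interpolation is a no-op
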
@@ -865,8 +868,7 @@ class SamVisionLayer(nn.Module):
         pad_h = (window_size - height % window_size) % window_size
         pad_w = (window_size - width % window_size) % window_size
-        if pad_h > 0 or pad_w > 0:
-            hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
+        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
         pad_height, pad_width = height + pad_h, width + pad_w
 
         hidden_states = hidden_states.reshape(
@@ -902,8 +904,7 @@ class SamVisionLayer(nn.Module):
             hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1)
         )
-        if pad_height > height or pad_width > width:
-            hidden_states = hidden_states[:, :height, :width, :].contiguous()
+        hidden_states = hidden_states[:, :height, :width, :].contiguous()
         return hidden_states
 
     def forward(
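
A matching standalone check for the two window-partition hunks above (not part of this diff): F.pad with zero padding amounts and a full-size crop are both no-ops, so running them unconditionally preserves the eager-mode result while keeping the shape-dependent ifs out of the trace. The tensor sizes are arbitrary:

import torch
import torch.nn.functional as F

hidden_states = torch.randn(1, 14, 14, 32)  # (batch, height, width, channels), already window-aligned
height, width, window_size = 14, 14, 14
pad_h = (window_size - height % window_size) % window_size  # 0
pad_w = (window_size - width % window_size) % window_size   # 0

padded = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))   # no-op pad
cropped = padded[:, :height, :width, :]                     # no-op crop

print(torch.equal(hidden_states, cropped))  # True
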