"...resources/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "41d96cd89acbe5fa77a9ac87516ff1a4c9adb384"
Unverified commit 9ade58f0, authored by Michael Benayoun, committed by GitHub

[ONNX] Sam fix (#23110)



* [WIP] Fix for the ONNX export

* Apply changes

* Remove commented code

* Resolve todo

* empty -> zeros

* fix slow tests

---------
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
parent 4baa34c1
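
The commit message above says this fixes the ONNX export of SAM. For orientation, a minimal export sketch is shown below: it wraps `SamModel` so the traced graph returns plain tensors and calls `torch.onnx.export` on dummy inputs. The wrapper class, checkpoint name, dummy image, output file name, and opset choice are all illustrative assumptions, not part of this commit or of the official export path in `optimum`.

```python
# Minimal sketch of exporting SAM's prompt-conditioned path to ONNX.
# Assumes the Hugging Face `transformers` SamModel/SamProcessor API; checkpoint,
# file name, and opset below are illustrative, not prescriptive.
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor


class SamOnnxWrapper(torch.nn.Module):
    # Hypothetical thin wrapper so the exported graph returns plain tensors
    # instead of a ModelOutput dataclass.
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, pixel_values, input_points, input_labels):
        outputs = self.model(
            pixel_values=pixel_values,
            input_points=input_points,
            input_labels=input_labels,
            return_dict=True,
        )
        return outputs.iou_scores, outputs.pred_masks


model = SamModel.from_pretrained("facebook/sam-vit-base").eval()
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
wrapper = SamOnnxWrapper(model).eval()

# Dummy image and a single foreground point prompt.
image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
inputs = processor(image, input_points=[[[320, 240]]], return_tensors="pt")
input_labels = torch.ones_like(inputs["input_points"][..., 0], dtype=torch.long)

with torch.no_grad():
    torch.onnx.export(
        wrapper,
        (inputs["pixel_values"], inputs["input_points"], input_labels),
        "sam.onnx",
        input_names=["pixel_values", "input_points", "input_labels"],
        output_names=["iou_scores", "pred_masks"],
        opset_version=17,
    )
```

The hunks below all serve the same goal: replacing data-dependent branches, zero-sized placeholders, and in-place masked updates with operations the tracer can record faithfully.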
@@ -223,9 +223,7 @@ class SamAttention(nn.Module):
     def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
         batch, n_heads, n_tokens, c_per_head = hidden_states.shape
         hidden_states = hidden_states.transpose(1, 2)
-        return hidden_states.reshape(
-            batch // max(1, point_batch_size), point_batch_size, n_tokens, n_heads * c_per_head
-        )
+        return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
 
     def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
         # Input projections
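
This first hunk drops the defensive `max(1, point_batch_size)` so `_recombine_heads` is a single reshape. A standalone sketch of the shape arithmetic it performs, with made-up sizes:

```python
# Standalone sketch of the head-recombination shape arithmetic in SamAttention.
# All sizes below are made up for illustration.
import torch

batch_size, point_batch_size, n_tokens, n_heads, c_per_head = 2, 3, 5, 8, 32

# During attention, the batch and point-batch dimensions are folded together:
# (batch * point_batch, n_heads, n_tokens, c_per_head)
hidden_states = torch.randn(batch_size * point_batch_size, n_heads, n_tokens, c_per_head)

# _recombine_heads: move heads next to the channel dim, then unfold the point batch.
recombined = hidden_states.transpose(1, 2).reshape(
    (batch_size * point_batch_size) // point_batch_size,  # == batch_size
    point_batch_size,
    n_tokens,
    n_heads * c_per_head,
)
print(recombined.shape)  # torch.Size([2, 3, 5, 256])
```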
@@ -482,7 +480,7 @@ class SamMaskDecoder(nn.Module):
                 Whether or not to return the attentions tensors of all attention layers.
         """
         batch_size, num_channels, height, width = image_embeddings.shape
-        point_batch_size = max(1, sparse_prompt_embeddings.shape[1])
+        point_batch_size = sparse_prompt_embeddings.shape[1]
         # Concatenate output tokens
         output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
         output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
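
With the guard gone here as well, `point_batch_size` is read straight from the sparse prompt embeddings and the learned output tokens are repeated across both the image batch and the point batch. A shape sketch with illustrative sizes:

```python
# Shape sketch for broadcasting the decoder's learned output tokens across the
# image batch and the point batch. Sizes and the number of tokens are illustrative.
import torch

batch_size, hidden_size = 2, 256
num_output_tokens = 5  # e.g. 1 IoU token plus several mask tokens

sparse_prompt_embeddings = torch.randn(batch_size, 3, 7, hidden_size)
point_batch_size = sparse_prompt_embeddings.shape[1]

output_tokens = torch.randn(num_output_tokens, hidden_size)  # stand-in for learned weights
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
print(output_tokens.shape)  # torch.Size([2, 3, 5, 256])
```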
@@ -634,8 +632,18 @@ class SamPromptEncoder(nn.Module):
             torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
         )
-        point_embedding[labels == 0] += self.point_embed[0].weight
-        point_embedding[labels == 1] += self.point_embed[1].weight
+        point_embedding = torch.where(
+            (labels == 0)[:, :, :, None],
+            point_embedding + self.point_embed[0].weight[None, None, :, :],
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 1)[:, :, :, None],
+            point_embedding + self.point_embed[1].weight[None, None, :, :],
+            point_embedding,
+        )
         return point_embedding
 
     def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
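
This hunk replaces boolean-mask in-place additions with out-of-place `torch.where` selects, which the ONNX tracer can record. A small check that the two formulations agree, using made-up shapes and a stand-in for `point_embed[0].weight`:

```python
# The in-place masked update and the torch.where rewrite compute the same thing;
# the latter is friendlier to torch.onnx tracing. Shapes here are illustrative.
import torch

batch, point_batch, num_points, hidden = 2, 3, 4, 8
point_embedding = torch.randn(batch, point_batch, num_points, hidden)
labels = torch.randint(0, 2, (batch, point_batch, num_points))
weight = torch.randn(1, hidden)  # stand-in for point_embed[0].weight

# Original formulation: boolean-mask indexing with an in-place add.
reference = point_embedding.clone()
reference[labels == 0] += weight

# ONNX-friendly formulation: expand the mask and select with torch.where.
rewritten = torch.where(
    (labels == 0)[:, :, :, None],
    point_embedding + weight[None, None, :, :],
    point_embedding,
)

print(torch.allclose(reference, rewritten))  # True
```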
@@ -675,8 +683,7 @@ class SamPromptEncoder(nn.Module):
             if input_labels is None:
                 raise ValueError("If points are provided, labels must also be provided.")
             point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
-            sparse_embeddings = torch.empty((batch_size, point_batch_size, 0, self.hidden_size), device=target_device)
-            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=2)
+            sparse_embeddings = point_embeddings
         if input_boxes is not None:
             batch_size = input_boxes.shape[0]
             box_embeddings = self._embed_boxes(input_boxes)
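
Concatenating the point embeddings onto a zero-width `torch.empty` placeholder is a no-op, so the hunk assigns them directly and avoids a zero-sized tensor in the traced graph. A quick equivalence check with illustrative shapes:

```python
# Concatenating onto a zero-width placeholder just returns the point embeddings,
# so the direct assignment is equivalent. Shapes are illustrative.
import torch

batch_size, point_batch_size, num_points, hidden_size = 2, 3, 4, 256
point_embeddings = torch.randn(batch_size, point_batch_size, num_points, hidden_size)

placeholder = torch.empty((batch_size, point_batch_size, 0, hidden_size))
via_cat = torch.cat([placeholder, point_embeddings], dim=2)

print(torch.equal(via_cat, point_embeddings))  # True
```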
@@ -692,7 +699,7 @@ class SamPromptEncoder(nn.Module):
             )
         if sparse_embeddings is None:
-            sparse_embeddings = torch.empty((batch_size, 0, 1, self.hidden_size), device=target_device)
+            sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)
         return sparse_embeddings, dense_embeddings
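
The fallback for the no-prompt case switches from a zero-sized `torch.empty` to a single all-zero slot built with `torch.zeros` (the "empty -> zeros" bullet in the commit message), so downstream code always sees a deterministic, non-empty prompt dimension. A tiny comparison with an illustrative hidden size:

```python
# The old placeholder had a zero-sized dimension (no elements at all); the new one
# is a single all-zero prompt slot, which is deterministic and keeps downstream
# shapes non-degenerate. Hidden size is illustrative.
import torch

batch_size, hidden_size = 2, 256

old_placeholder = torch.empty((batch_size, 0, 1, hidden_size))  # zero elements
new_placeholder = torch.zeros((batch_size, 1, 1, hidden_size))  # one all-zero slot

print(old_placeholder.shape, new_placeholder.shape)
print(new_placeholder.abs().sum().item())  # 0.0, always
```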
@@ -742,8 +749,6 @@ class SamVisionAttention(nn.Module):
             Extracted positional embeddings according to relative positions.
         """
         max_rel_dist = int(2 * max(q_size, k_size) - 1)
-        # Interpolate rel pos if needed.
-        if rel_pos.shape[0] != max_rel_dist:
         # Interpolate rel pos.
         rel_pos_resized = F.interpolate(
             rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
@@ -751,8 +756,6 @@ class SamVisionAttention(nn.Module):
             mode="linear",
         )
         rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
-        else:
-            rel_pos_resized = rel_pos
 
         # Scale the coords with short length if shapes for q and k are different.
         q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
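
These two hunks make the relative-position interpolation unconditional: tracing only records the branch taken for the example input, and when the table already has `max_rel_dist` entries, linear interpolation to the same length reproduces it up to float rounding. A sketch verifying that with illustrative sizes:

```python
# When the relative-position table already has max_rel_dist entries, linear
# interpolation to the same length reproduces it (up to float rounding), so the
# data-dependent branch can be dropped for tracing. Sizes are illustrative.
import torch
import torch.nn.functional as F

q_size = k_size = 14
max_rel_dist = int(2 * max(q_size, k_size) - 1)  # 27
rel_pos = torch.randn(max_rel_dist, 64)          # already the right length

rel_pos_resized = F.interpolate(
    rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
    size=max_rel_dist,
    mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)

print(torch.allclose(rel_pos, rel_pos_resized, atol=1e-6))  # True
```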
@@ -865,7 +868,6 @@ class SamVisionLayer(nn.Module):
         pad_h = (window_size - height % window_size) % window_size
         pad_w = (window_size - width % window_size) % window_size
-        if pad_h > 0 or pad_w > 0:
         hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
         pad_height, pad_width = height + pad_h, width + pad_w
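
Same idea for window padding: `F.pad` with zero padding amounts leaves the tensor unchanged, so the `if pad_h > 0 or pad_w > 0` guard is another shape-dependent branch that can be removed. Illustrative check:

```python
# Padding by zero on every side leaves the tensor unchanged, so the
# `if pad_h > 0 or pad_w > 0` guard is unnecessary. Shapes are illustrative.
import torch
import torch.nn.functional as F

hidden_states = torch.randn(1, 28, 28, 256)  # (batch, height, width, channels)
window_size = 14
height, width = hidden_states.shape[1], hidden_states.shape[2]

pad_h = (window_size - height % window_size) % window_size  # 0 here
pad_w = (window_size - width % window_size) % window_size   # 0 here

padded = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
print(torch.equal(padded, hidden_states))  # True when pad_h == pad_w == 0
```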
@@ -902,7 +904,6 @@ class SamVisionLayer(nn.Module):
             hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1)
         )
-        if pad_height > height or pad_width > width:
         hidden_states = hidden_states[:, :height, :width, :].contiguous()
         return hidden_states
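
And the matching un-padding guard goes too, since slicing back to the original height and width is a no-op when nothing was padded. Illustrative check:

```python
# Slicing back to the original height and width is a no-op when no padding was
# added, so the `if pad_height > height or pad_width > width` guard can also go.
import torch

height, width = 28, 28
pad_height, pad_width = height, width  # nothing was padded in this example

hidden_states = torch.randn(1, pad_height, pad_width, 256)
cropped = hidden_states[:, :height, :width, :].contiguous()

print(torch.equal(cropped, hidden_states))  # True
```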