"model_cards/vscode:/vscode.git/clone" did not exist on "3653d01f2af0389207f2239875a8ceae41bf0598"
Unverified Commit 75444551 authored by fxmarty's avatar fxmarty Committed by GitHub
Browse files

Make sam ONNX exportable (#22915)



* fix code not exportable

* fix

* Update src/transformers/models/sam/modeling_sam.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent d03d8c72
......@@ -623,10 +623,16 @@ class SamPromptEncoder(nn.Module):
input_shape = (self.input_image_size, self.input_image_size)
point_embedding = self.shared_embedding(points, input_shape)
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == -10] = 0.0 # ignore points
# torch.where and expanding the labels tensor is required by the ONNX export
point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
# This is required for the ONNX export. The dtype, device need to be explicitely
# specificed as otherwise torch.onnx.export interprets as double
point_embedding = torch.where(
labels[..., None] != -10,
point_embedding,
torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
)
point_embedding[labels == 0] += self.point_embed[0].weight
point_embedding[labels == 1] += self.point_embed[1].weight
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment