Unverified Commit 10c2ac7b authored by Pasquale De Marinis's avatar Pasquale De Marinis Committed by GitHub
Browse files

Fixed OwlViTModel inplace operations (#24529)

* fixed OwlViTModel inplace operations

* fixed operands order in owlvit
parent 66954ea2
......@@ -1294,8 +1294,8 @@ class OwlViTClassPredictionHead(nn.Module):
return (pred_logits, image_class_embeds)
# Normalize image and text features
image_class_embeds /= torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6
query_embeds /= torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6
image_class_embeds = image_class_embeds / (torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6)
query_embeds = query_embeds / (torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6)
# Get class predictions
pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment