Unverified Commit 843fdf2e authored by Orr Zohar's avatar Orr Zohar Committed by GitHub
Browse files

Fixing class embedding selection in owl-vit (#23157)

fixing class embedding selection in owl-vit
parent bbfb9fc2
...@@ -1499,7 +1499,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel): ...@@ -1499,7 +1499,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
selected_inds = (ious[0] >= iou_threshold).nonzero() selected_inds = (ious[0] >= iou_threshold).nonzero()
if selected_inds.numel(): if selected_inds.numel():
selected_embeddings = class_embeds[i][selected_inds[0]] selected_embeddings = class_embeds[i][selected_inds.squeeze(1)]
mean_embeds = torch.mean(class_embeds[i], axis=0) mean_embeds = torch.mean(class_embeds[i], axis=0)
mean_sim = torch.einsum("d,id->i", mean_embeds, selected_embeddings) mean_sim = torch.einsum("d,id->i", mean_embeds, selected_embeddings)
best_box_ind = selected_inds[torch.argmin(mean_sim)] best_box_ind = selected_inds[torch.argmin(mean_sim)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment