Unverified Commit bf0e0941 authored by Alara Dirik's avatar Alara Dirik Committed by GitHub
Browse files

Fix redundant normalization of OWL-ViT text embeddings (#19712)

parent 71ca7944
......@@ -1070,12 +1070,12 @@ class OwlViTModel(OwlViTPreTrainedModel):
image_embeds = self.visual_projection(image_embeds)
# normalized features
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
image_embeds_norm = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True)
text_embeds_norm = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
logits_per_text = torch.matmul(text_embeds_norm, image_embeds_norm.t()) * logit_scale
logits_per_image = logits_per_text.t()
loss = None
......@@ -1085,6 +1085,9 @@ class OwlViTModel(OwlViTPreTrainedModel):
if return_base_image_embeds:
last_hidden_state = vision_outputs[0]
image_embeds = self.vision_model.post_layernorm(last_hidden_state)
else:
image_embeds = image_embeds_norm
text_embeds = text_embeds_norm
if not return_dict:
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment