Unverified Commit e44b878c authored by Billy Cao's avatar Billy Cao Committed by GitHub
Browse files

Fix float out of range in owlvit and owlv2 when using FP16 or lower precision (#31657)

parent 75a63198
......@@ -1276,7 +1276,7 @@ class Owlv2ClassPredictionHead(nn.Module):
if query_mask.ndim > 1:
query_mask = torch.unsqueeze(query_mask, dim=-2)
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
pred_logits = torch.where(query_mask == 0, torch.finfo(pred_logits.dtype).min, pred_logits)
pred_logits = pred_logits.to(torch.float32)
return (pred_logits, image_class_embeds)
......
......@@ -1257,7 +1257,7 @@ class OwlViTClassPredictionHead(nn.Module):
if query_mask.ndim > 1:
query_mask = torch.unsqueeze(query_mask, dim=-2)
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
pred_logits = torch.where(query_mask == 0, torch.finfo(pred_logits.dtype).min, pred_logits)
pred_logits = pred_logits.to(torch.float32)
return (pred_logits, image_class_embeds)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment