"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "09e841490c3ae89a99a2d6289cde66656cb16dac"
Unverified Commit e44b878c authored by Billy Cao's avatar Billy Cao Committed by GitHub
Browse files

Fix float out of range in owlvit and owlv2 when using FP16 or lower precision (#31657)

parent 75a63198
...@@ -1276,7 +1276,7 @@ class Owlv2ClassPredictionHead(nn.Module): ...@@ -1276,7 +1276,7 @@ class Owlv2ClassPredictionHead(nn.Module):
if query_mask.ndim > 1: if query_mask.ndim > 1:
query_mask = torch.unsqueeze(query_mask, dim=-2) query_mask = torch.unsqueeze(query_mask, dim=-2)
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits) pred_logits = torch.where(query_mask == 0, torch.finfo(pred_logits.dtype).min, pred_logits)
pred_logits = pred_logits.to(torch.float32) pred_logits = pred_logits.to(torch.float32)
return (pred_logits, image_class_embeds) return (pred_logits, image_class_embeds)
......
...@@ -1257,7 +1257,7 @@ class OwlViTClassPredictionHead(nn.Module): ...@@ -1257,7 +1257,7 @@ class OwlViTClassPredictionHead(nn.Module):
if query_mask.ndim > 1: if query_mask.ndim > 1:
query_mask = torch.unsqueeze(query_mask, dim=-2) query_mask = torch.unsqueeze(query_mask, dim=-2)
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits) pred_logits = torch.where(query_mask == 0, torch.finfo(pred_logits.dtype).min, pred_logits)
pred_logits = pred_logits.to(torch.float32) pred_logits = pred_logits.to(torch.float32)
return (pred_logits, image_class_embeds) return (pred_logits, image_class_embeds)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment