Unverified Commit c31473ed authored by Pavel Iakubovskii's avatar Pavel Iakubovskii Committed by GitHub
Browse files

Remove float64 cast for OwlVit and OwlV2 to support MPS device (#31071)

Remove float64
parent 936ab7ba
...@@ -1276,7 +1276,6 @@ class Owlv2ClassPredictionHead(nn.Module): ...@@ -1276,7 +1276,6 @@ class Owlv2ClassPredictionHead(nn.Module):
if query_mask.ndim > 1: if query_mask.ndim > 1:
query_mask = torch.unsqueeze(query_mask, dim=-2) query_mask = torch.unsqueeze(query_mask, dim=-2)
pred_logits = pred_logits.to(torch.float64)
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits) pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
pred_logits = pred_logits.to(torch.float32) pred_logits = pred_logits.to(torch.float32)
......
...@@ -1257,7 +1257,6 @@ class OwlViTClassPredictionHead(nn.Module): ...@@ -1257,7 +1257,6 @@ class OwlViTClassPredictionHead(nn.Module):
if query_mask.ndim > 1: if query_mask.ndim > 1:
query_mask = torch.unsqueeze(query_mask, dim=-2) query_mask = torch.unsqueeze(query_mask, dim=-2)
pred_logits = pred_logits.to(torch.float64)
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits) pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
pred_logits = pred_logits.to(torch.float32) pred_logits = pred_logits.to(torch.float32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment