"tests/utils/test_logging.py" did not exist on "461ae86812f9d75762bbdae2ac5776f9a5d702ea"
Unverified Commit c31473ed authored by Pavel Iakubovskii's avatar Pavel Iakubovskii Committed by GitHub
Browse files

Remove float64 cast for OwlVit and OwlV2 to support MPS device (#31071)

Remove float64
parent 936ab7ba
......@@ -1276,7 +1276,6 @@ class Owlv2ClassPredictionHead(nn.Module):
if query_mask.ndim > 1:
query_mask = torch.unsqueeze(query_mask, dim=-2)
pred_logits = pred_logits.to(torch.float64)
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
pred_logits = pred_logits.to(torch.float32)
......
......@@ -1257,7 +1257,6 @@ class OwlViTClassPredictionHead(nn.Module):
if query_mask.ndim > 1:
query_mask = torch.unsqueeze(query_mask, dim=-2)
pred_logits = pred_logits.to(torch.float64)
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
pred_logits = pred_logits.to(torch.float32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment