Commit e619db24 authored by Pedro Cuenca, committed by GitHub

mps cross-attention hack: don't crash on fp16 (#2258)

* mps cross-attention hack: don't crash on fp16

* Make conversion explicit.
parent 111228cb
...
@@ -251,7 +251,7 @@ class CrossAttention(nn.Module):
             # HACK: MPS: Does not support padding by greater than dimension of input tensor.
             # Instead, we can manually construct the padding tensor.
             padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
-            padding = torch.zeros(padding_shape, device=attention_mask.device)
+            padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
             attention_mask = torch.concat([attention_mask, padding], dim=2)
         else:
             attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
...
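
For context, here is a minimal standalone sketch of the same padding construction, runnable outside diffusers. The mask shape and target_length value are made up for illustration; the point of the change is that torch.zeros defaults to float32, so the padding must be created with the mask's dtype (float16 under fp16 inference) for the concatenation not to crash on MPS.

import torch

# Hypothetical half-precision attention mask: (batch, heads, current_length).
attention_mask = torch.zeros(2, 8, 64, dtype=torch.float16)
target_length = 13  # hypothetical number of positions to pad on the last dim

# Build the padding manually instead of using F.pad (MPS cannot pad by more
# than the input tensor's dimension), matching both dtype and device of the
# mask so the concat does not mix float16 with float32.
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
attention_mask = torch.concat([attention_mask, padding], dim=2)

assert attention_mask.dtype == torch.float16
assert attention_mask.shape == (2, 8, 64 + 13)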