Commit ac7d8cfa authored by comfyanonymous's avatar comfyanonymous

Allow attn_mask in attention_pytorch.

parent 1a4bd9e9
@@ -284,7 +284,7 @@ def attention_pytorch(q, k, v, heads, mask=None):
         (q, k, v),
     )
-    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
-    if exists(mask):
-        raise NotImplementedError
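The change above forwards the mask directly to PyTorch's `scaled_dot_product_attention` instead of raising `NotImplementedError`. A minimal sketch of the patched behavior, with illustrative tensor shapes that are assumptions (not taken from the commit):

```python
import torch

# Hypothetical shapes: (batch, heads, tokens, dim_head)
q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)

# Boolean attn_mask broadcast over batch and heads; True = attend.
mask = torch.ones(8, 8, dtype=torch.bool)

# After this commit, the mask is passed straight through rather than rejected.
out = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
)
print(out.shape)  # torch.Size([1, 4, 8, 16])
```

With an all-True mask this reduces to unmasked attention; callers can instead pass a float mask, which is added to the attention scores before the softmax.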