Commit 0c2c9fbd authored by comfyanonymous

Support attention mask in split attention.

parent 3ad0191b
@@ -239,6 +239,12 @@ def attention_split(q, k, v, heads, mask=None):
         else:
             s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale

+        if mask is not None:
+            if len(mask.shape) == 2:
+                s1 += mask[i:end]
+            else:
+                s1 += mask[:, i:end]
+
         s2 = s1.softmax(dim=-1).to(v.dtype)
         del s1
         first_op_done = True
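The mask is added to the raw attention scores one query chunk at a time, so it must be sliced along the query dimension to line up with each chunk: a 2D mask of shape (q_len, k_len) is indexed as mask[i:end], while a batched 3D mask of shape (batch, q_len, k_len) is indexed as mask[:, i:end]. The sketch below (the helper split_attention_scores is hypothetical, not ComfyUI code) checks that this per-chunk slicing reproduces full masked attention:

    # Minimal standalone sketch of the slicing rule in this commit.
    import torch
    from torch import einsum

    def split_attention_scores(q, k, mask=None, chunk=2):
        scale = q.shape[-1] ** -0.5
        out = []
        for i in range(0, q.shape[1], chunk):
            end = i + chunk
            # Scores for one block of query rows against all keys: (b, chunk, k_len)
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
            if mask is not None:
                if len(mask.shape) == 2:
                    s1 += mask[i:end]       # 2D mask broadcasts over the batch dim
                else:
                    s1 += mask[:, i:end]    # 3D mask: slice the query dimension
            out.append(s1.softmax(dim=-1))
        return torch.cat(out, dim=1)

    # The chunked computation matches full masked attention.
    b, n, d = 1, 4, 8
    q, k = torch.randn(b, n, d), torch.randn(b, n, d)
    causal = torch.triu(torch.full((n, n), float('-inf')), diagonal=1)
    full = (einsum('b i d, b j d -> b i j', q, k) * d ** -0.5 + causal).softmax(dim=-1)
    assert torch.allclose(split_attention_scores(q, k, causal), full, atol=1e-6)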