Unverified commit 1b54b9f9 authored by Casper, committed by GitHub
Browse files

Merge pull request #96 from casper-hansen/fix_attention_mask

Only apply attention mask if seqlen is greater than 1
parents 0baf5e18 e94b7f40
@@ -176,7 +176,8 @@ class QuantAttentionFused(nn.Module):
         if self.use_alibi:
             scores = self.alibi.forward(scores, seqlen)

-        if attention_mask is not None:
+        # When seqlen is 1, there is nothing else to attend to
+        if attention_mask is not None and seqlen > 1:
             scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
         scores = F.softmax(scores.float(), dim=-1).type_as(xq)
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment