Commit d7badefc authored by Casper Hansen

Switch to torch SDPA

parent 54f02854
@@ -154,6 +154,28 @@ class QuantAttentionFused(nn.Module):
         self.alibi_slopes = None
         self.is_neox = True
 
+    def _multi_query_attention_torch(self, query, key, value, batch_size, seqlen, use_cache, past_key_value, attention_mask):
+        query = query.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
+        key = key.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
+        value = value.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
+
+        if use_cache:
+            key = key.contiguous()
+            value = value.contiguous()
+            query = query.contiguous()
+
+        output = F.scaled_dot_product_attention(
+            query, key, value,
+            is_causal=past_key_value is None,
+            attn_mask=attention_mask)
+        del query, key, value
+
+        output = output.transpose(1, 2).reshape(batch_size, seqlen, self.hidden_size)
+
+        return output
+
     def forward(
         self,
         hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
@@ -186,24 +208,12 @@ class QuantAttentionFused(nn.Module):
             self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
             self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store
 
-            keys = xk
-            values = xv
             past_key_value = (xk, xv) if use_cache else None
-
-            xq = xq.transpose(1, 2)
-            keys = keys.transpose(1, 2)
-            values = values.transpose(1, 2)
-            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-            if self.use_alibi:
-                scores += self.alibi_bias[..., :seqlen]
-            if attention_mask is not None:
-                scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
-
-            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
-            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
-            output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+            output = self._multi_query_attention_torch(
+                xq, xk, xv,
+                bsz, seqlen, True,
+                past_key_value, attention_mask
+            )
         else:
             xq = xq[:, 0, :, :]
             xk = xk[:, 0, :, :]
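
A quick standalone sketch (not part of the commit) of the equivalence the switch relies on: with no mask and no causal flag, torch.nn.functional.scaled_dot_product_attention computes the same softmax(QK^T / sqrt(d)) @ V as the removed manual path. Shapes and tolerance below are illustrative assumptions; the commit additionally sets is_causal based on whether a KV cache entry is present.

# Standalone comparison, assuming PyTorch >= 2.0; shapes are illustrative only.
import math
import torch
import torch.nn.functional as F

bsz, n_heads, seqlen, head_dim = 2, 8, 16, 64
xq = torch.randn(bsz, n_heads, seqlen, head_dim)
keys = torch.randn(bsz, n_heads, seqlen, head_dim)
values = torch.randn(bsz, n_heads, seqlen, head_dim)

# Removed path: explicit scores, softmax, and weighted sum.
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
manual = torch.matmul(scores, values)

# New path: fused SDPA (flash/memory-efficient backend when available).
fused = F.scaled_dot_product_attention(xq, keys, values)

print(torch.allclose(manual, fused, atol=1e-5))  # expected: True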