Commit ebbbc3a5 authored by Casper Hansen

Create torch SDPA implementation

parent 890b6aa7
@@ -98,6 +98,7 @@ class QuantLlamaAttentionFused(nn.Module):
         self.qkv_proj = qkv_layer
         self.o_proj = o_proj
         self.start_pos = 0
+        self.use_sdpa_torch = False
 
         # following fastertransformer definition
         self.cache_v = (
@@ -116,7 +117,26 @@ class QuantLlamaAttentionFused(nn.Module):
             hidden_size // num_heads,
             max_position_embeddings * 2,
         ).to(dev)
 
+    def _multi_query_attention_torch(self, query, key, value, batch_size, seqlen, use_cache, past_key_value):
+        # faster prompt processing
+        query = query.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
+        key = key.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
+        value = value.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
+
+        if use_cache:
+            key = key.contiguous()
+            value = value.contiguous()
+            query = query.contiguous()
+
+        output = F.scaled_dot_product_attention(query, key, value, is_causal=past_key_value is None)
+        del query, key, value
+
+        output = output.transpose(1, 2).reshape(batch_size, seqlen, self.hidden_size)
+
+        return output
+
     def forward(
         self,
         hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
@@ -152,15 +172,18 @@ class QuantLlamaAttentionFused(nn.Module):
             values = xv
             past_key_value = (xk, xv) if use_cache else None
 
-            xq = xq.transpose(1, 2)
-            keys = keys.transpose(1, 2)
-            values = values.transpose(1, 2)
-            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
-            if attention_mask is not None:
-                scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
-            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
-            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
-            output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+            if self.use_sdpa_torch:
+                output = self._multi_query_attention_torch(xq, xk, xv, bsz, seqlen, True, past_key_value)
+            else:
+                xq = xq.transpose(1, 2)
+                keys = keys.transpose(1, 2)
+                values = values.transpose(1, 2)
+                scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+                if attention_mask is not None:
+                    scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
+                scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+                output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
+                output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
         else:
             xq = xq[:, 0, :, :]
             xk = xk[:, 0, :, :]
...
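
For context on what the new path buys: `F.scaled_dot_product_attention` computes the same softmax(QK^T / sqrt(d)) V result as the manual branch it replaces, but as a single fused call that can dispatch to a flash or memory-efficient backend. A minimal standalone sketch of that equivalence, with made-up shapes that are not taken from this commit:

import math
import torch
import torch.nn.functional as F

# Illustrative shapes only; none of these values come from the commit.
bsz, n_heads, seqlen, head_dim = 1, 32, 16, 64
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_heads, seqlen, head_dim)
v = torch.randn(bsz, n_heads, seqlen, head_dim)

# Manual attention, as in the else-branch of forward(): explicit scores + softmax.
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
causal_mask = torch.triu(torch.full((seqlen, seqlen), float("-inf")), diagonal=1)
scores = scores + causal_mask
manual = torch.matmul(F.softmax(scores, dim=-1), v)

# Fused kernel, as called in _multi_query_attention_torch().
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

assert torch.allclose(manual, fused, atol=1e-4)

In eager mode the manual branch materializes the full `scores` matrix; the fused call can avoid that when a flash or memory-efficient backend is eligible, which is presumably the prompt-processing speedup the new method's comment refers to.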
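
The flag defaults to `False` and this commit does not show where it is toggled. A hypothetical way to opt in, assuming the fused attention modules are reachable as `model.model.layers[i].self_attn` (an assumption, not something this diff establishes):

# Assumption: `model` wraps a Llama-style stack whose attention modules have been
# swapped for QuantLlamaAttentionFused; the attribute path is illustrative only.
for layer in model.model.layers:
    attn = layer.self_attn
    if hasattr(attn, "use_sdpa_torch"):
        attn.use_sdpa_torch = True  # route prompt processing through the torch SDPA path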