Commit ebbbc3a5 authored by Casper Hansen

Create torch SDPA implementation

parent 890b6aa7
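
For context, this commit routes prompt processing through torch.nn.functional.scaled_dot_product_attention instead of the module's manual attention math. The snippet below is a minimal standalone sketch (not part of the commit) showing that SDPA with is_causal=True matches the usual softmax(QK^T/sqrt(d))V computation on (batch, heads, seqlen, head_dim) tensors; all tensor names and sizes are illustrative.

```python
# Standalone sketch: compare torch SDPA against a manual causal attention reference.
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, n_heads, seqlen, head_dim = 2, 32, 16, 128
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_heads, seqlen, head_dim)
v = torch.randn(bsz, n_heads, seqlen, head_dim)

# Manual reference: softmax(Q K^T / sqrt(d)) V with a causal mask.
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
causal_mask = torch.triu(torch.full((seqlen, seqlen), float("-inf")), diagonal=1)
ref = torch.softmax(scores + causal_mask, dim=-1) @ v

# Fused path used by the new code.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(ref, out, atol=1e-5))  # expected: True
```
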
@@ -98,6 +98,7 @@ class QuantLlamaAttentionFused(nn.Module):
        self.qkv_proj = qkv_layer
        self.o_proj = o_proj
        self.start_pos = 0
        self.use_sdpa_torch = False

        # following fastertransformer definition
        self.cache_v = (
@@ -117,6 +118,25 @@ class QuantLlamaAttentionFused(nn.Module):
            max_position_embeddings * 2,
        ).to(dev)
    def _multi_query_attention_torch(self, query, key, value, batch_size, seqlen, use_cache, past_key_value):
        # faster prompt processing
        query = query.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)

        if use_cache:
            key = key.contiguous()
            value = value.contiguous()
            query = query.contiguous()

        output = F.scaled_dot_product_attention(query, key, value, is_causal=past_key_value is None)
        del query, key, value

        output = output.transpose(1, 2).reshape(batch_size, seqlen, self.hidden_size)
        return output
    def forward(
        self,
        hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
@@ -152,6 +172,9 @@ class QuantLlamaAttentionFused(nn.Module):
        values = xv
        past_key_value = (xk, xv) if use_cache else None

        if self.use_sdpa_torch:
            output = self._multi_query_attention_torch(xq, xk, xv, bsz, seqlen, True, past_key_value)
        else:
            xq = xq.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.transpose(1, 2)
......
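
The new use_sdpa_torch flag defaults to False, and the diff does not show how it gets switched on. A hypothetical way to opt in after the quantized model is assembled (assuming the fused attention layers are reachable through named_modules) might look like:

```python
# Hypothetical usage sketch, not part of the commit.
def enable_torch_sdpa(model):
    for name, module in model.named_modules():
        # Match by class name to avoid assuming the module's import path.
        if type(module).__name__ == "QuantLlamaAttentionFused":
            module.use_sdpa_torch = True  # route forward() through the SDPA branch
    return model
```

If the class is importable in your setup, an isinstance check would be the more robust alternative to matching on the class name.
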