Commit ebbbc3a5 authored by Casper Hansen

Create torch SDPA implementation
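This commit routes attention through torch.nn.functional.scaled_dot_product_attention (SDPA) instead of the manual matmul/softmax path, which is kept as a fallback branch. A minimal sketch of the equivalence the new branch relies on; the tensor sizes below are illustrative only and are not taken from this diff:

    import torch
    import torch.nn.functional as F

    bsz, n_heads, seqlen, head_dim = 1, 32, 16, 128  # illustrative sizes only
    q = torch.randn(bsz, n_heads, seqlen, head_dim)
    k = torch.randn(bsz, n_heads, seqlen, head_dim)
    v = torch.randn(bsz, n_heads, seqlen, head_dim)

    # Fused kernel: softmax(q @ k^T / sqrt(head_dim)) @ v with a causal mask
    out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # Manual reference, matching the existing matmul/softmax branch in forward()
    scores = torch.matmul(q, k.transpose(2, 3)) / (head_dim ** 0.5)
    causal_mask = torch.triu(torch.full((seqlen, seqlen), float("-inf")), diagonal=1)
    scores = scores + causal_mask
    out_manual = torch.matmul(F.softmax(scores.float(), dim=-1).type_as(q), v)

    print((out_sdpa - out_manual).abs().max())  # expected to be ~0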

parent 890b6aa7
@@ -98,6 +98,7 @@ class QuantLlamaAttentionFused(nn.Module):
         self.qkv_proj = qkv_layer
         self.o_proj = o_proj
         self.start_pos = 0
+        self.use_sdpa_torch = False

         # following fastertransformer definition
         self.cache_v = (
@@ -116,7 +117,26 @@
             hidden_size // num_heads,
             max_position_embeddings * 2,
         ).to(dev)

+    def _multi_query_attention_torch(self, query, key, value, batch_size, seqlen, use_cache, past_key_value):
+        # Torch SDPA path for faster prompt processing.
+        # Reshape to (batch, n_heads, seqlen, head_dim) as expected by scaled_dot_product_attention.
+        query = query.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
+        key = key.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
+        value = value.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
+
+        if use_cache:
+            # Make the transposed views contiguous before handing them to the fused kernel.
+            key = key.contiguous()
+            value = value.contiguous()
+            query = query.contiguous()
+
+        # Causal masking is applied only when no past key/values were passed in.
+        output = F.scaled_dot_product_attention(query, key, value, is_causal=past_key_value is None)
+        del query, key, value
+
+        # Back to (batch, seqlen, hidden_size) for the output projection.
+        output = output.transpose(1, 2).reshape(batch_size, seqlen, self.hidden_size)
+        return output
+
     def forward(
         self,
         hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
@@ -152,15 +172,18 @@ class QuantLlamaAttentionFused(nn.Module):
             values = xv
             past_key_value = (xk, xv) if use_cache else None

-            xq = xq.transpose(1, 2)
-            keys = keys.transpose(1, 2)
-            values = values.transpose(1, 2)
-            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
-            if attention_mask is not None:
-                scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
-            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
-            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
-            output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+            if self.use_sdpa_torch:
+                output = self._multi_query_attention_torch(xq, xk, xv, bsz, seqlen, True, past_key_value)
+            else:
+                xq = xq.transpose(1, 2)
+                keys = keys.transpose(1, 2)
+                values = values.transpose(1, 2)
+                scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+                if attention_mask is not None:
+                    scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
+                scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+                output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
+                output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
         else:
             xq = xq[:, 0, :, :]
             xk = xk[:, 0, :, :]
......
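Note that use_sdpa_torch defaults to False, so the new path is opt-in, and nothing in this diff flips the flag. A hypothetical way to enable it after the fused modules have been injected into the model; the traversal below is illustrative and not part of this commit:

    # hypothetical: opt in to the torch SDPA path on every fused attention block
    for module in model.modules():
        if isinstance(module, QuantLlamaAttentionFused):
            module.use_sdpa_torch = True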