Commit 7631add1 authored by Casper Hansen

Fix attention, support alibi

parent 90f54dbb
@@ -154,28 +154,6 @@ class QuantAttentionFused(nn.Module):
         self.alibi_slopes = None
         self.is_neox = True
 
-    def _multi_query_attention_torch(self, query, key, value, batch_size, seqlen, use_cache, past_key_value, attention_mask):
-        query = query.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
-        key = key.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
-        value = value.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
-
-        if use_cache:
-            key = key.contiguous()
-            value = value.contiguous()
-            query = query.contiguous()
-
-        output = F.scaled_dot_product_attention(
-            query, key, value,
-            is_causal=past_key_value is None,
-            attn_mask=attention_mask)
-        del query, key, value
-
-        output = output.transpose(1, 2).reshape(batch_size, seqlen, self.hidden_size)
-
-        return output
-
     def forward(
         self,
         hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
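Note: the removed _multi_query_attention_torch helper delegated to F.scaled_dot_product_attention; the new forward path (next hunk) computes the scores explicitly so an ALiBi bias can be added before the softmax. A sanity sketch, not part of the commit and using toy shapes, showing the two paths agree when no bias or mask is applied:

# Sanity sketch, not part of the commit: with no ALiBi bias and no attention
# mask, the explicit matmul/softmax path matches F.scaled_dot_product_attention.
import math
import torch
import torch.nn.functional as F

bsz, n_heads, seqlen, head_dim = 2, 4, 8, 16  # toy shapes
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_heads, seqlen, head_dim)
v = torch.randn(bsz, n_heads, seqlen, head_dim)

# Fused path, as used by the removed helper (non-causal here for comparison).
fused = F.scaled_dot_product_attention(q, k, v)

# Explicit path, as in the new forward(): an ALiBi bias or an additive
# attention mask would be added to `scores` before the softmax.
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
manual = torch.matmul(F.softmax(scores.float(), dim=-1).type_as(q), v)

assert torch.allclose(fused, manual, atol=1e-5)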
@@ -208,18 +186,30 @@ class QuantAttentionFused(nn.Module):
             self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
             self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store
 
+            keys = xk
+            values = xv
+
             past_key_value = (xk, xv) if use_cache else None
-            output = self._multi_query_attention_torch(
-                xq, xk, xv,
-                bsz, seqlen, True,
-                past_key_value, attention_mask
-            )
+
+            xq = xq.transpose(1, 2)
+            keys = keys.transpose(1, 2)
+            values = values.transpose(1, 2)
+            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+            if self.use_alibi:
+                scores += self.alibi_bias[..., :seqlen]
+
+            if attention_mask is not None:
+                scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
+            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
+            attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
         else:
             xq = xq[:, 0, :, :]
             xk = xk[:, 0, :, :]
             xv = xv[:, 0, :, :]
             past_key_value = (xk, xv) if use_cache else None
-            output = awq_inference_engine.single_query_attention(
+            attention_weight = awq_inference_engine.single_query_attention(
                 xq, # query
                 xk, # key
                 xv, # value
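Note: the new path reads self.alibi_bias and self.alibi_slopes, which are prepared elsewhere in the module. For reference, a minimal sketch of one common ALiBi precomputation, using a hypothetical build_alibi_bias helper rather than this module's actual setup, that yields a tensor sliceable as [..., :seqlen] and addable to the raw scores:

# Hypothetical helper, not from this commit: one common way to precompute an
# ALiBi bias of shape (n_heads, 1, max_seq_len) that can be sliced with
# [..., :seqlen] and added to the raw attention scores.
import torch

def build_alibi_bias(n_heads: int, max_seq_len: int) -> torch.Tensor:
    # Geometric head slopes from the ALiBi paper, 2^(-8 * i / n_heads),
    # assuming n_heads is a power of two (other head counts interpolate).
    slopes = torch.tensor([2.0 ** (-8.0 * (i + 1) / n_heads) for i in range(n_heads)])
    # ALiBi penalizes distance: -slope * (query_pos - key_pos). Because softmax
    # is shift-invariant per query row, the per-row constant -slope * query_pos
    # can be dropped, leaving a bias that depends only on the key position.
    key_positions = torch.arange(max_seq_len, dtype=torch.float32)
    return slopes.view(n_heads, 1, 1) * key_positions.view(1, 1, max_seq_len)

bias = build_alibi_bias(n_heads=8, max_seq_len=2048)
print(bias.shape)  # torch.Size([8, 1, 2048])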
@@ -232,13 +222,13 @@ class QuantAttentionFused(nn.Module):
                 10000, # rotary embedding base
                 self.is_neox, # is neox
             )
-            output = output.reshape(bsz, 1, -1)
+            attention_weight = attention_weight.reshape(bsz, 1, -1)
 
-        attn_output = self.o_proj(output)
+        attn_output = self.o_proj(attention_weight)
 
         if use_cache:
             self.start_pos += seqlen
         else:
             self.start_pos = 0
 
-        return attn_output, None, past_key_value
+        return attn_output, attention_weight, past_key_value
\ No newline at end of file
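Note: the start_pos update at the end of forward pairs with the cache_k / cache_v writes in the previous hunk. A standalone sketch of that rolling-cache pattern, with hypothetical names and a plain (batch, heads, seq, head_dim) layout; the real module packs cache_k in a kernel-specific layout:

# Minimal sketch (names hypothetical) of the rolling-cache bookkeeping:
# keys/values for each call are written at start_pos, and start_pos only
# advances when use_cache is True.
import torch

class TinyKVCache:
    def __init__(self, max_batch, n_heads, max_seq_len, head_dim):
        self.cache_k = torch.zeros(max_batch, n_heads, max_seq_len, head_dim)
        self.cache_v = torch.zeros(max_batch, n_heads, max_seq_len, head_dim)
        self.start_pos = 0

    def update(self, k, v, use_cache=True):
        bsz, _, seqlen, _ = k.shape
        # Write the new chunk at the current position.
        self.cache_k[:bsz, :, self.start_pos:self.start_pos + seqlen] = k
        self.cache_v[:bsz, :, self.start_pos:self.start_pos + seqlen] = v
        # Return everything cached so far, including the new chunk.
        keys = self.cache_k[:bsz, :, :self.start_pos + seqlen]
        values = self.cache_v[:bsz, :, :self.start_pos + seqlen]
        if use_cache:
            self.start_pos += seqlen   # next call appends after this chunk
        else:
            self.start_pos = 0         # no caching: always start over
        return keys, values

cache = TinyKVCache(max_batch=1, n_heads=4, max_seq_len=64, head_dim=16)
k = v = torch.randn(1, 4, 8, 16)
keys, values = cache.update(k, v)                       # prefill: 8 positions
keys, values = cache.update(k[:, :, :1], v[:, :, :1])   # decode step: 9 total
print(keys.shape)  # torch.Size([1, 4, 9, 16])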