Unverified Commit 92a403b2 authored by Younes Belkada, committed by GitHub

[`core` / `attention`] Fix fused attention generation with newest transformers version (#146)


Co-authored-by: Casper <casperbh.96@gmail.com>
parent 3eda6562
@@ -122,6 +122,18 @@ class QuantAttentionFused(nn.Module):
    def forward(self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs):
        bsz, seqlen, _ = hidden_states.shape
        # Check if we are under the transformers caching regime
        has_past_key_value = kwargs is not None and "past_key_value" in kwargs and kwargs["past_key_value"] is not None

        if has_past_key_value:
            # In the newest transformers versions, when caching is used, the input hidden states
            # no longer consist of only the last generated token; they can span the whole sequence
            # (minus the past KV length). We only need the last token, so slice it out and set `seqlen = 1`.
            if seqlen > 1:
                seqlen = 1
                hidden_states = hidden_states[:, -1:]
        if bsz != self.cache_batch_size:
            raise RuntimeError(
                f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
@@ -166,6 +178,7 @@ class QuantAttentionFused(nn.Module):
        # Only necessary to retrieve from cache when we are not processing context
        if seqlen == 1:
            xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)

        keys = xk
        values = xv
@@ -185,7 +198,6 @@ class QuantAttentionFused(nn.Module):
        # When seqlen is 1, there is nothing else to attend to
        if attention_mask is not None and seqlen > 1:
            scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
        attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
@@ -215,4 +227,4 @@ class QuantAttentionFused(nn.Module):
        # past_key_value is replaced with cache_v, cache_k, returning empty data
        past_key_value = [torch.Tensor([ [ [[0]], [[0]], [[0]] ] ])]
        return attn_output, attention_weight, past_key_value
\ No newline at end of file
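
For context, a minimal standalone sketch (not part of the commit) of the slicing behaviour the new `has_past_key_value` block introduces: when a cache is passed and the incoming hidden states cover more than one position, only the last token is kept and `seqlen` is forced to 1. The tensor sizes and the stand-in cache object below are illustrative only.

import torch

# Illustrative shapes; in the real module these come from the model config.
bsz, seqlen, hidden_dim = 2, 5, 64
hidden_states = torch.randn(bsz, seqlen, hidden_dim)
past_key_value = object()  # stand-in for a non-empty transformers KV cache

# Mirror of the logic added in the patch above.
has_past_key_value = past_key_value is not None
if has_past_key_value and seqlen > 1:
    seqlen = 1
    hidden_states = hidden_states[:, -1:]

print(hidden_states.shape)  # torch.Size([2, 1, 64])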