Unverified commit 8e7059a7 authored by qwopqwop200, committed by GitHub

fix bug: take the full-attention path when seqlen > 1 or single-query attention is unavailable, and read keys/values back from the KV cache on single-token steps

parent e1884728
@@ -186,7 +186,7 @@ class QuantAttentionFused(nn.Module):
         xk = self.attention_shapes["xk_slice"](xqkv)
         xv = self.attention_shapes["xv_slice"](xqkv)

-        if seqlen > 1 and have_single_query_attention:
+        if seqlen > 1 or not(have_single_query_attention):
             xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
             xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
             xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])
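This one-line change is the core of the fix: the reshape-and-full-attention branch now runs for any multi-token input, and also for single-token steps when single-query attention is unavailable; previously it required both a multi-token input and single-query attention support at once. A minimal sketch comparing the two gates (the standalone functions are hypothetical; only the boolean expressions come from the diff):

```python
# Hypothetical standalone versions of the gate; only the boolean
# expressions are taken from the commit itself.

def old_gate(seqlen: int, have_single_query_attention: bool) -> bool:
    # Before the fix: fires only for multi-token inputs AND models with
    # single-query attention, so models without that kernel never take
    # the full-attention branch.
    return seqlen > 1 and have_single_query_attention

def new_gate(seqlen: int, have_single_query_attention: bool) -> bool:
    # After the fix: any multi-token input takes the branch, and so does
    # a single-token step whenever single-query attention is unavailable.
    return seqlen > 1 or not have_single_query_attention

# The two gates differ exactly when single-query attention is unavailable.
for seqlen in (1, 8):
    for sqa in (True, False):
        print(f"seqlen={seqlen} sqa={sqa}: "
              f"old={old_gate(seqlen, sqa)} new={new_gate(seqlen, sqa)}")
```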
@@ -207,10 +207,13 @@ class QuantAttentionFused(nn.Module):
             self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
             self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store

+            if seqlen == 1:
+                xv = self.cache_v[:bsz, :, : self.start_pos + seqlen, :].transpose(1, 2).contiguous()
+                xk = self.cache_k[:bsz, :, :, : self.start_pos + seqlen, :].transpose(2, 3).contiguous()
+                xk = xk.reshape(xk.shape[:-2] + (self.head_dim,)).transpose(1, 2).contiguous()
             keys = xk
             values = xv
             past_key_value = (xk, xv) if use_cache else None
             xq = xq.transpose(1, 2)
             keys = keys.transpose(1, 2)
             values = values.transpose(1, 2)
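The added `if seqlen == 1:` block is what makes the new gate safe: a single-token step on the full-attention path must attend over the whole history, so keys and values are read back from the fused cache rather than taken from the current token alone. The indexing implies `cache_v` is laid out as [bsz, n_heads, max_seq, head_dim] and `cache_k` keeps `head_dim` split across two trailing dims; those shapes are inferred from the diff, not stated in the commit. A self-contained sketch with toy sizes:

```python
import torch

# Toy sizes; in the real module these come from the model config. The cache
# layouts below are assumptions inferred from the diff's indexing, not taken
# from the repository.
bsz, n_heads, max_seq, head_dim, pack = 1, 4, 16, 8, 2
start_pos, seqlen = 5, 1  # decoding one token after 5 cached positions

# cache_v: [bsz, n_heads, max_seq, head_dim]
# cache_k: [bsz, n_heads, head_dim // pack, max_seq, pack] (packed layout)
cache_v = torch.randn(bsz, n_heads, max_seq, head_dim)
cache_k = torch.randn(bsz, n_heads, head_dim // pack, max_seq, pack)

# Values: slice the positions seen so far, then swap heads and sequence,
# mirroring `cache_v[...].transpose(1, 2)` in the diff.
xv = cache_v[:bsz, :, : start_pos + seqlen, :].transpose(1, 2).contiguous()

# Keys: slice, bring the two packed dims next to each other, then fold them
# back into head_dim, mirroring the transpose/reshape pair in the diff.
xk = cache_k[:bsz, :, :, : start_pos + seqlen, :].transpose(2, 3).contiguous()
xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous()

assert xv.shape == (bsz, start_pos + seqlen, n_heads, head_dim)
assert xk.shape == (bsz, start_pos + seqlen, n_heads, head_dim)
```

Both tensors come out as [bsz, seq, n_heads, head_dim], which is what the `keys = xk` / `values = xv` lines expect before the later `transpose(1, 2)` moves heads ahead of the sequence dimension for the attention matmul.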
@@ -256,4 +259,4 @@ class QuantAttentionFused(nn.Module):
         else:
             self.start_pos = 0

-        return attn_output, attention_weight, past_key_value
\ No newline at end of file
+        return attn_output, attention_weight, past_key_value