Commit 2d593b84 authored by Casper Hansen

Fix attention shapes. Add repeat interleave.

parent e3936a44
@@ -121,6 +121,7 @@ class QuantAttentionFused(nn.Module):
         self.hidden_size = hidden_size
         self.n_heads = n_heads
         self.n_kv_heads = n_kv_heads
+        self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
         self.head_dim = self.hidden_size // n_heads
         self.qkv_proj = qkv_layer
         self.o_proj = o_proj
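For context, `n_kv_groups` is the number of query heads that share each key/value head under grouped-query attention. A minimal sketch of the arithmetic, using hypothetical head counts (not taken from this commit):

```python
# Hypothetical example values; the real counts come from the model config.
n_heads = 32      # query heads
n_kv_heads = 8    # key/value heads (grouped-query attention)

n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
print(n_kv_groups)  # 4 -> each KV head is shared by 4 query heads
```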
@@ -157,19 +158,17 @@
             # 8: pack 8 fp16 in FT, if fp32 then use 4
             "cache_k": (self.cache_batch_size, self.n_kv_heads, self.head_dim // 8, max_seq_len, 8,),
             "xqkv_view": (self.n_heads + self.n_kv_heads * 2, self.head_dim),
-            "xq_slice": lambda xqkv: xqkv[:, :, 0 : self.n_kv_heads],
+            "xq_slice": lambda xqkv: xqkv[:, :, 0 : self.n_heads],
             "xk_slice": lambda xqkv: xqkv[:, :, self.n_heads : (self.n_heads + self.n_kv_heads)],
             "xv_slice": lambda xqkv: xqkv[:, :, -self.n_kv_heads :],
-            "xq_view": (self.n_kv_heads, self.head_dim),
+            "xq_view": (self.n_heads, self.head_dim),
             "xk_view": (self.n_kv_heads, self.head_dim),
             "xv_view": (self.n_kv_heads, self.head_dim),
             "xk_reshape": (self.n_kv_heads, self.head_dim // 8, 8),
-            "single_xq_view": (self.n_kv_heads, self.head_dim),
+            "single_xq_view": (self.n_heads, self.head_dim),
             "single_xk_view": (self.n_kv_heads, self.head_dim),
             "single_xv_view": (self.n_kv_heads, self.head_dim)
         }
-        print(self.attention_shapes)
         self.cache_v = (
             torch.zeros(self.attention_shapes["cache_v"]).to(dev).half()
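The shape fix: the fused QKV output packs `n_heads` query heads plus `2 * n_kv_heads` key/value heads along the head axis, so the query slice and views must use `n_heads`, not `n_kv_heads` (the two only coincide when there is no grouped-query attention). A rough sketch of that layout with made-up sizes, not values from the commit:

```python
import torch

# Assumed illustrative sizes only.
batch, seqlen, n_heads, n_kv_heads, head_dim = 2, 5, 32, 8, 128

# Fused QKV output viewed as (batch, seq, n_heads + 2 * n_kv_heads, head_dim)
xqkv = torch.randn(batch, seqlen, n_heads + n_kv_heads * 2, head_dim)

xq = xqkv[:, :, 0:n_heads]                        # query heads
xk = xqkv[:, :, n_heads:(n_heads + n_kv_heads)]   # key heads
xv = xqkv[:, :, -n_kv_heads:]                     # value heads

print(xq.shape, xk.shape, xv.shape)
# torch.Size([2, 5, 32, 128]) torch.Size([2, 5, 8, 128]) torch.Size([2, 5, 8, 128])
```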
@@ -234,6 +233,11 @@ class QuantAttentionFused(nn.Module):
             keys = xk
             values = xv
+            if self.n_kv_groups != 0:
+                keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
+                values = torch.repeat_interleave(values, dim=2, repeats=self.n_kv_groups)
             past_key_value = (xk, xv) if use_cache else None
             xq = xq.transpose(1, 2)
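The `repeat_interleave` addition expands the key/value tensors along the head dimension so each KV head is duplicated `n_kv_groups` times and lines up with the query heads before the attention matmul, while the cache still stores the unexpanded `xk`/`xv`. A self-contained sketch of that expansion, assuming tensors shaped (batch, seq, heads, head_dim) as in the example above:

```python
import torch

# Assumed illustrative sizes only.
batch, seqlen, n_heads, n_kv_heads, head_dim = 2, 5, 32, 8, 128
n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0

keys = torch.randn(batch, seqlen, n_kv_heads, head_dim)
values = torch.randn(batch, seqlen, n_kv_heads, head_dim)

if n_kv_groups != 0:
    # dim=2 is the head axis; each KV head is repeated n_kv_groups times
    keys = torch.repeat_interleave(keys, dim=2, repeats=n_kv_groups)
    values = torch.repeat_interleave(values, dim=2, repeats=n_kv_groups)

print(keys.shape, values.shape)
# torch.Size([2, 5, 32, 128]) torch.Size([2, 5, 32, 128])
```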