Commit bc4f93c2 authored by Casper Hansen
Browse files

xq view key

parent ba4da393
......@@ -134,6 +134,7 @@ class QuantAttentionFused(nn.Module):
"xk_slice": lambda xqkv: xqkv[:, :, 1],
"xv_slice": lambda xqkv: xqkv[:, :, 2],
"xk_reshape": (self.n_local_heads, self.head_dim // 8, 8),
"xq_view": (self.n_local_heads, self.head_dim),
"xk_view": (self.n_local_heads, self.head_dim),
"xv_view": (self.n_local_heads, self.head_dim),
"single_xq_view": (self.n_local_heads, self.head_dim),
......@@ -171,12 +172,13 @@ class QuantAttentionFused(nn.Module):
bsz, seqlen, _ = hidden_states.shape
xqkv = self.qkv_proj(hidden_states)
xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
xq = self.attention_shapes["xq_slice"](xqkv)
xk = self.attention_shapes["xk_slice"](xqkv)
xv = self.attention_shapes["xv_slice"](xqkv)
if seqlen > 1:
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment