Commit 7f8f9f16 authored by Casper Hansen

xk_reshape key

parent a8c9afd5
@@ -114,7 +114,7 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         return query, key
 
 class QuantAttentionFused(nn.Module):
-    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_seq_len, use_alibi=False):
+    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_seq_len, use_alibi=False, attention_shapes=None):
         super().__init__()
         self.hidden_size = hidden_size
         self.n_local_heads = num_heads
@@ -124,7 +124,7 @@ class QuantAttentionFused(nn.Module):
         self.start_pos = 0
         self.use_alibi = use_alibi
         self.cache_batch_size = 1
-        self.attention_shapes = {
+        self.attention_shapes = attention_shapes if attention_shapes is not None else {
             # following fastertransformer definition
             "cache_v": (self.cache_batch_size, self.n_local_heads, max_seq_len, self.head_dim,),
             # 8: pack 8 fp16 in FT, if fp32 then use 4
@@ -133,13 +133,14 @@ class QuantAttentionFused(nn.Module):
             "xq_slice": lambda xqkv: xqkv[:, :, 0],
             "xk_slice": lambda xqkv: xqkv[:, :, 1],
             "xv_slice": lambda xqkv: xqkv[:, :, 2],
+            "xk_reshape": (self.n_local_heads, self.head_dim // 8, 8),
             "xk_view": (self.n_local_heads, self.head_dim),
             "xv_view": (self.n_local_heads, self.head_dim),
             "single_xq_view": (self.n_local_heads, self.head_dim),
             "single_xk_view": (self.n_local_heads, self.head_dim),
             "single_xv_view": (self.n_local_heads, self.head_dim)
         }
 
         self.cache_v = (
             torch.zeros(self.attention_shapes["cache_v"]).to(dev).half()
         )
@@ -187,7 +188,7 @@ class QuantAttentionFused(nn.Module):
         values_store = xv.transpose(2, 1)
         keys_store = (
-            xk.reshape(bsz, seqlen, self.n_local_heads, self.head_dim // 8, 8)
+            xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
             .permute(0, 2, 3, 1, 4)
             .contiguous()
        )
...
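Note (not part of the commit): the snippet below is a minimal, standalone sketch of what the new "xk_reshape" entry controls. Before this change the target shape of the key-cache reshape was hard-coded in forward(); now it is read from attention_shapes, so a caller that passes its own attention_shapes dict into QuantAttentionFused can change how keys are packed. The Llama-7B-style dimensions and the loose tensor names here are assumptions for illustration only.

import torch

# Illustrative sizes (assumed, not taken from the commit)
bsz, seqlen, n_heads, head_dim = 1, 16, 32, 128
xk = torch.randn(bsz, seqlen, n_heads, head_dim, dtype=torch.float16)

attention_shapes = {
    # 8: pack 8 fp16 values per group, following the FasterTransformer K-cache layout;
    # an fp32 cache would use 4 instead, per the comment in the shape table above
    "xk_reshape": (n_heads, head_dim // 8, 8),
}

# Same expression as the new keys_store code path
keys_store = (
    xk.reshape((bsz, seqlen) + attention_shapes["xk_reshape"])
    .permute(0, 2, 3, 1, 4)
    .contiguous()
)
print(keys_store.shape)  # torch.Size([1, 32, 16, 16, 8])

Keeping only the (bsz, seqlen) prefix in code means a caller can swap the per-head packing (for example a different packing factor) by supplying a different tuple, without overriding forward().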