Commit 3256ffec authored by Casper Hansen

Support kv_heads

parent 63a12504
@@ -100,6 +100,7 @@ class LlamaFuser:
         attn = QuantAttentionFused(
             module.hidden_size,
             module.num_heads,
+            module.num_key_value_heads,
             qkv_layer,
             module.o_proj,
             next(iter(qkv_layer.state_dict().values())).device,
...
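For context, a minimal sketch (not part of this commit) of where the extra argument comes from on the Llama side: `num_key_value_heads` is the Hugging Face `LlamaConfig` field used for grouped-query attention, and the `getattr` fallback to `num_attention_heads` for older configs is an assumption.

```python
# Hypothetical illustration: how the value that LlamaFuser now forwards as
# module.num_key_value_heads is typically defined in a transformers LlamaConfig.
from transformers import LlamaConfig

# Illustrative GQA configuration (e.g. a 70B-class model): 64 query heads, 8 KV heads.
config = LlamaConfig(hidden_size=8192, num_attention_heads=64, num_key_value_heads=8)

num_heads = config.num_attention_heads
# Assumption: older configs without GQA may lack the field, so fall back to MHA.
num_kv_heads = getattr(config, "num_key_value_heads", num_heads)

print(num_heads, num_kv_heads, num_heads // num_kv_heads)  # 64 8 8 -> 8 query heads share each KV head
```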
@@ -115,35 +115,62 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         return query, key
 
 class QuantAttentionFused(nn.Module):
-    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_seq_len,
+    def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
                  use_alibi=False, attention_shapes=None):
         super().__init__()
         self.hidden_size = hidden_size
-        self.n_local_heads = num_heads
-        self.head_dim = self.hidden_size // num_heads
+        self.n_heads = n_heads
+        self.n_kv_heads = n_kv_heads
+        self.head_dim = self.hidden_size // n_heads
         self.qkv_proj = qkv_layer
         self.o_proj = o_proj
         self.start_pos = 0
         self.use_alibi = use_alibi
         self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
 
-        self.attention_shapes = attention_shapes if attention_shapes is not None else {
-            # following fastertransformer definition
-            "cache_v": (self.cache_batch_size, self.n_local_heads, max_seq_len, self.head_dim,),
-            # 8: pack 8 fp16 in FT, if fp32 then use 4
-            "cache_k": (self.cache_batch_size, self.n_local_heads, self.head_dim // 8, max_seq_len, 8,),
-            "xqkv_view": (-1, self.n_local_heads, self.head_dim),
-            "xq_slice": lambda xqkv: xqkv[:, :, 0],
-            "xk_slice": lambda xqkv: xqkv[:, :, 1],
-            "xv_slice": lambda xqkv: xqkv[:, :, 2],
-            "xk_reshape": (self.n_local_heads, self.head_dim // 8, 8),
-            "xq_view": (self.n_local_heads, self.head_dim),
-            "xk_view": (self.n_local_heads, self.head_dim),
-            "xv_view": (self.n_local_heads, self.head_dim),
-            "single_xq_view": (self.n_local_heads, self.head_dim),
-            "single_xk_view": (self.n_local_heads, self.head_dim),
-            "single_xv_view": (self.n_local_heads, self.head_dim)
-        }
+        if attention_shapes is not None:
+            self.attention_shapes = attention_shapes
+        elif self.n_kv_heads == 0:
+            self.attention_shapes = {
+                # following fastertransformer definition
+                "cache_v": (self.cache_batch_size, self.n_heads, max_seq_len, self.head_dim,),
+                # 8: pack 8 fp16 in FT, if fp32 then use 4
+                "cache_k": (self.cache_batch_size, self.n_heads, self.head_dim // 8, max_seq_len, 8,),
+                "xqkv_view": (-1, self.n_heads, self.head_dim),
+                "xq_slice": lambda xqkv: xqkv[:, :, 0],
+                "xk_slice": lambda xqkv: xqkv[:, :, 1],
+                "xv_slice": lambda xqkv: xqkv[:, :, 2],
+                "xq_view": (self.n_heads, self.head_dim),
+                "xk_view": (self.n_heads, self.head_dim),
+                "xv_view": (self.n_heads, self.head_dim),
+                "xk_reshape": (self.n_heads, self.head_dim // 8, 8),
+                "single_xq_view": (self.n_heads, self.head_dim),
+                "single_xk_view": (self.n_heads, self.head_dim),
+                "single_xv_view": (self.n_heads, self.head_dim)
+            }
+        else:
+            self.attention_shapes = {
+                # following fastertransformer definition
+                "cache_v": (self.cache_batch_size, self.n_kv_heads, max_seq_len, self.head_dim,),
+                # 8: pack 8 fp16 in FT, if fp32 then use 4
+                "cache_k": (self.cache_batch_size, self.n_kv_heads, self.head_dim // 8, max_seq_len, 8,),
+                "xqkv_view": (self.n_heads + self.n_kv_heads * 2, self.head_dim),
+                "xq_slice": lambda xqkv: xqkv[:, :, 0 : self.n_kv_heads],
+                "xk_slice": lambda xqkv: xqkv[:, :, self.n_heads : (self.n_heads + self.n_kv_heads)],
+                "xv_slice": lambda xqkv: xqkv[:, :, -self.n_kv_heads :],
+                "xq_view": (self.n_kv_heads, self.head_dim),
+                "xk_view": (self.n_kv_heads, self.head_dim),
+                "xv_view": (self.n_kv_heads, self.head_dim),
+                "xk_reshape": (self.n_kv_heads, self.head_dim // 8, 8),
+                "single_xq_view": (self.n_kv_heads, self.head_dim),
+                "single_xk_view": (self.n_kv_heads, self.head_dim),
+                "single_xv_view": (self.n_kv_heads, self.head_dim)
+            }
+
+        print(self.attention_shapes)
+
         self.cache_v = (
             torch.zeros(self.attention_shapes["cache_v"]).to(dev).half()
         )
@@ -153,14 +180,14 @@ class QuantAttentionFused(nn.Module):
         )
 
         if use_alibi:
-            alibi_slopes, alibi_bias = build_alibi_bias(self.n_local_heads, max_seq_len)
+            alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
             self.alibi_slopes = alibi_slopes.float().to(dev)
             self.alibi_bias = alibi_bias.float().to(dev)
             self.rotary_dim = 0
             self.is_neox = False
         else:
             self.freqs_cis = precompute_freqs_cis(
-                hidden_size // num_heads,
+                hidden_size // n_heads,
                 max_seq_len * 2,
             ).to(dev)
             self.rotary_dim = self.head_dim
...
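To make the shape change concrete, here is a small self-contained sketch with illustrative numbers (not taken from the commit) of what the two new branches imply for the packed QKV width and the KV cache allocation:

```python
# Illustrative values only; the real ones come from the model config.
hidden_size, n_heads, n_kv_heads = 8192, 64, 8
head_dim = hidden_size // n_heads  # 128
max_seq_len, cache_batch_size = 2048, 1

# n_kv_heads == 0 branch (plain multi-head attention): QKV packs 3 * n_heads heads.
mha_qkv_width = 3 * n_heads * head_dim                              # 24576
mha_cache_v = (cache_batch_size, n_heads, max_seq_len, head_dim)    # cache sized per query head

# grouped-query branch: n_heads query heads plus 2 * n_kv_heads key/value heads.
gqa_qkv_width = (n_heads + 2 * n_kv_heads) * head_dim               # 10240
gqa_cache_v = (cache_batch_size, n_kv_heads, max_seq_len, head_dim) # cache sized per KV head only

print(mha_qkv_width, gqa_qkv_width)  # the fused qkv_layer output dim shrinks accordingly
print(mha_cache_v, gqa_cache_v)      # the KV cache no longer scales with the query head count
```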
@@ -6,9 +6,13 @@ class MPTBlock(nn.Module):
     def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len):
         super().__init__()
         self.n_heads = n_heads
+        self.n_kv_heads = 0
         self.hidden_size = hidden_size
         self.norm_1 = norm_1
-        self.attn = QuantAttentionFused(hidden_size, self.n_heads, qkv_layer, o_proj, dev=dev, max_seq_len=max_seq_len, use_alibi=True).to(dev)
+        self.attn = QuantAttentionFused(
+            hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
+            dev=dev, max_seq_len=max_seq_len, use_alibi=True
+        ).to(dev)
         self.norm_2 = norm_2
         self.ffn = mpt_mlp.to(dev)
@@ -30,16 +34,18 @@ class MPTBlock(nn.Module):
         return out, None, past_key_value
 
 class FalconDecoderLayer(nn.Module):
-    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len, input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True):
+    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len,
+                 input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True):
         super().__init__()
         self.n_heads = n_heads
+        self.n_kv_heads = 8
         self.hidden_size = hidden_size
         self.new_decoder_arch = new_decoder_arch
         attention_shapes = self._get_attention_shapes(n_heads, max_seq_len, self.hidden_size // n_heads, new_decoder_arch)
 
         # TODO: Falcon has ALiBi implemented but which model uses it?
         self.attn = QuantAttentionFused(
-            hidden_size, self.n_heads, qkv_layer, o_proj,
+            hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
             dev=dev, max_seq_len=max_seq_len, use_alibi=False,
             attention_shapes=attention_shapes
         ).to(dev)
@@ -67,10 +73,10 @@ class FalconDecoderLayer(nn.Module):
                 "xq_slice": lambda xqkv: xqkv[:, :, :, 0],
                 "xk_slice": lambda xqkv: xqkv[:, :, :, 1],
                 "xv_slice": lambda xqkv: xqkv[:, :, :, 2],
-                "xk_reshape": (1, head_dim // 8, 8),
                 "xq_view": (1, head_dim),
                 "xk_view": (1, head_dim),
                 "xv_view": (1, head_dim),
+                "xk_reshape": (1, head_dim // 8, 8),
                 "single_xq_view": (n_heads, head_dim),
                 "single_xk_view": (1, 8, head_dim),
                 "single_xv_view": (1, 8, head_dim)
@@ -85,10 +91,10 @@ class FalconDecoderLayer(nn.Module):
                 "xq_slice": lambda xqkv: xqkv[:, :, :-2],
                 "xk_slice": lambda xqkv: xqkv[:, :, [-2]],
                 "xv_slice": lambda xqkv: xqkv[:, :, [-1]],
-                "xk_reshape": (1, head_dim // 8, 8),
                 "xq_view": (n_heads, head_dim),
                 "xk_view": (1, head_dim),
                 "xv_view": (1, head_dim),
+                "xk_reshape": (1, head_dim // 8, 8),
                 "single_xq_view": (n_heads, head_dim),
                 "single_xk_view": (1, head_dim),
                 "single_xv_view": (1, head_dim)
...
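The three call sites now encode the convention differently: MPTBlock passes n_kv_heads = 0 (plain multi-head shapes), LlamaFuser passes the config's num_key_value_heads, and FalconDecoderLayer hard-codes 8 but still supplies its own attention_shapes, which takes precedence over both built-in branches. A minimal sketch of that dispatch order, using a hypothetical helper name rather than the class itself:

```python
# Self-contained sketch of the branch order introduced in QuantAttentionFused.__init__;
# pick_attention_shapes is a hypothetical stand-in, not part of the commit.
def pick_attention_shapes(n_heads, n_kv_heads, attention_shapes=None):
    if attention_shapes is not None:
        return "caller-supplied"                            # FalconDecoderLayer path
    elif n_kv_heads == 0:
        return "multi-head (3 * n_heads packed QKV)"        # MPTBlock passes n_kv_heads = 0
    else:
        return "grouped-query (n_heads + 2 * n_kv_heads)"   # LlamaFuser passes num_key_value_heads

print(pick_attention_shapes(32, 0))                         # MPT
print(pick_attention_shapes(64, 8))                         # Llama with GQA
print(pick_attention_shapes(128, 8, attention_shapes={}))   # Falcon, shapes precomputed by the block
```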