Unverified commit a5e8b048 authored by Casper, committed by GitHub

Merge pull request #60 from casper-hansen/kv_heads

Support kv_heads
parents bf76e108 a024e893
@@ -7,7 +7,10 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
     @staticmethod
     def fuse_layers(model: FalconForCausalLM, quant_config:dict):
         fuser = FalconFuser(model)
-        fuser.fuse_transformer()
+
+        # TODO: Implement correctly fused modules for Falcon 40B and Falcon 180B
+        if model.config.num_attention_heads == 71:
+            fuser.fuse_transformer()
 
     @staticmethod
     def get_model_layers(model: FalconForCausalLM):
...
@@ -100,6 +100,7 @@ class LlamaFuser:
         attn = QuantAttentionFused(
             module.hidden_size,
             module.num_heads,
+            module.num_key_value_heads,
             qkv_layer,
             module.o_proj,
             next(iter(qkv_layer.state_dict().values())).device,
...
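For context, a hedged sketch (not part of the diff) of where module.num_key_value_heads comes from: Hugging Face Llama configs expose both head counts, and their ratio is the number of query heads that share each key/value head.

from transformers import LlamaConfig

# Illustrative values only; real checkpoints define their own head counts.
cfg = LlamaConfig(num_attention_heads=64, num_key_value_heads=8)

n_heads = cfg.num_attention_heads      # 64 query heads
n_kv_heads = cfg.num_key_value_heads   # 8 key/value heads
print(n_heads // n_kv_heads)           # 8 query heads share each KV head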
@@ -61,34 +61,60 @@ def build_alibi_bias(

 class QuantAttentionFused(nn.Module):
-    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_seq_len,
+    def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
                  use_alibi=False, attention_shapes=None):
         super().__init__()
         self.hidden_size = hidden_size
-        self.n_local_heads = num_heads
-        self.head_dim = self.hidden_size // num_heads
+        self.n_heads = n_heads
+        self.n_kv_heads = n_kv_heads
+        self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
+        self.head_dim = self.hidden_size // n_heads
         self.qkv_proj = qkv_layer
         self.o_proj = o_proj
         self.start_pos = 0
         self.use_alibi = use_alibi
         self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
-        self.attention_shapes = attention_shapes if attention_shapes is not None else {
-            # following fastertransformer definition
-            "cache_v": (self.cache_batch_size, self.n_local_heads, max_seq_len, self.head_dim,),
-            # 8: pack 8 fp16 in FT, if fp32 then use 4
-            "cache_k": (self.cache_batch_size, self.n_local_heads, self.head_dim // 8, max_seq_len, 8,),
-            "xqkv_view": (-1, self.n_local_heads, self.head_dim),
-            "xq_slice": lambda xqkv: xqkv[:, :, 0],
-            "xk_slice": lambda xqkv: xqkv[:, :, 1],
-            "xv_slice": lambda xqkv: xqkv[:, :, 2],
-            "xk_reshape": (self.n_local_heads, self.head_dim // 8, 8),
-            "xq_view": (self.n_local_heads, self.head_dim),
-            "xk_view": (self.n_local_heads, self.head_dim),
-            "xv_view": (self.n_local_heads, self.head_dim),
-            "single_xq_view": (self.n_local_heads, self.head_dim),
-            "single_xk_view": (self.n_local_heads, self.head_dim),
-            "single_xv_view": (self.n_local_heads, self.head_dim)
-        }
+
+        if attention_shapes is not None:
+            self.attention_shapes = attention_shapes
+
+        elif self.n_kv_heads == 0:
+            self.attention_shapes = {
+                # following fastertransformer definition
+                "cache_v": (self.cache_batch_size, self.n_heads, max_seq_len, self.head_dim,),
+                # 8: pack 8 fp16 in FT, if fp32 then use 4
+                "cache_k": (self.cache_batch_size, self.n_heads, self.head_dim // 8, max_seq_len, 8,),
+                "xqkv_view": (-1, self.n_heads, self.head_dim),
+                "xq_slice": lambda xqkv: xqkv[:, :, 0],
+                "xk_slice": lambda xqkv: xqkv[:, :, 1],
+                "xv_slice": lambda xqkv: xqkv[:, :, 2],
+                "xq_view": (self.n_heads, self.head_dim),
+                "xk_view": (self.n_heads, self.head_dim),
+                "xv_view": (self.n_heads, self.head_dim),
+                "xk_reshape": (self.n_heads, self.head_dim // 8, 8),
+                "single_xq_view": (self.n_heads, self.head_dim),
+                "single_xk_view": (self.n_heads, self.head_dim),
+                "single_xv_view": (self.n_heads, self.head_dim)
+            }
+        else:
+            self.attention_shapes = {
+                # following fastertransformer definition
+                "cache_v": (self.cache_batch_size, self.n_kv_heads, max_seq_len, self.head_dim,),
+                # 8: pack 8 fp16 in FT, if fp32 then use 4
+                "cache_k": (self.cache_batch_size, self.n_kv_heads, self.head_dim // 8, max_seq_len, 8,),
+                "xqkv_view": (self.n_heads + self.n_kv_heads * 2, self.head_dim),
+                "xq_slice": lambda xqkv: xqkv[:, :, 0 : self.n_heads],
+                "xk_slice": lambda xqkv: xqkv[:, :, self.n_heads : (self.n_heads + self.n_kv_heads)],
+                "xv_slice": lambda xqkv: xqkv[:, :, -self.n_kv_heads :],
+                "xq_view": (self.n_heads, self.head_dim),
+                "xk_view": (self.n_kv_heads, self.head_dim),
+                "xv_view": (self.n_kv_heads, self.head_dim),
+                "xk_reshape": (self.n_kv_heads, self.head_dim // 8, 8),
+                "single_xq_view": (self.n_heads, self.head_dim),
+                "single_xk_view": (self.n_kv_heads, self.head_dim),
+                "single_xv_view": (self.n_kv_heads, self.head_dim)
+            }
+
         self.cache_v = (
             torch.zeros(self.attention_shapes["cache_v"]).to(dev).half()
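To make the grouped-query branch above concrete, here is a small self-contained sketch with toy sizes (chosen for illustration, not taken from the PR; the actual forward pass may reshape slightly differently) of how these shapes carve a fused QKV output into query, key, and value heads:

import torch

batch, seqlen, n_heads, n_kv_heads, head_dim = 1, 3, 8, 2, 4

# A fused QKV projection emits queries, then keys, then values,
# concatenated along the feature dimension.
xqkv = torch.randn(batch, seqlen, (n_heads + 2 * n_kv_heads) * head_dim)

# Mirror "xqkv_view" and the slice lambdas from the GQA branch.
xqkv = xqkv.view(batch, seqlen, n_heads + 2 * n_kv_heads, head_dim)
xq = xqkv[:, :, 0:n_heads]                     # (1, 3, 8, 4) query heads
xk = xqkv[:, :, n_heads:n_heads + n_kv_heads]  # (1, 3, 2, 4) key heads
xv = xqkv[:, :, -n_kv_heads:]                  # (1, 3, 2, 4) value heads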
@@ -99,14 +125,14 @@ class QuantAttentionFused(nn.Module):
         )
         if use_alibi:
-            alibi_slopes, alibi_bias = build_alibi_bias(self.n_local_heads, max_seq_len)
+            alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
             self.alibi_slopes = alibi_slopes.float().to(dev)
             self.alibi_bias = alibi_bias.float().to(dev)
             self.rotary_dim = 0
             self.is_neox = False
         else:
             self.freqs_cis = precompute_freqs_cis(
-                hidden_size // num_heads,
+                hidden_size // n_heads,
                 max_seq_len * 2,
             ).to(dev)
             self.rotary_dim = self.head_dim
@@ -153,6 +179,11 @@ class QuantAttentionFused(nn.Module):
         keys = xk
         values = xv
 
+        if self.n_kv_groups != 0:
+            keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
+            values = torch.repeat_interleave(values, dim=2, repeats=self.n_kv_groups)
+
         past_key_value = (xk, xv) if use_cache else None
         xq = xq.transpose(1, 2)
...
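The repeat_interleave call added to forward() is what lets the rest of the fused attention path stay unchanged: key/value heads are repeated head-wise until they line up one-to-one with the query heads. A minimal standalone sketch, assuming a (batch, seq, heads, head_dim) layout before the transpose:

import torch

batch, seqlen, n_heads, n_kv_heads, head_dim = 1, 3, 8, 2, 4
n_kv_groups = n_heads // n_kv_heads  # query heads served by each KV head

xq = torch.randn(batch, seqlen, n_heads, head_dim)
xk = torch.randn(batch, seqlen, n_kv_heads, head_dim)
xv = torch.randn(batch, seqlen, n_kv_heads, head_dim)

# dim=2 is the head dimension in this layout, matching the call in the diff.
keys = torch.repeat_interleave(xk, dim=2, repeats=n_kv_groups)
values = torch.repeat_interleave(xv, dim=2, repeats=n_kv_groups)
assert keys.shape == values.shape == xq.shape  # (1, 3, 8, 4)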
@@ -6,9 +6,13 @@ class MPTBlock(nn.Module):
     def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len):
         super().__init__()
         self.n_heads = n_heads
+        self.n_kv_heads = 0
         self.hidden_size = hidden_size
         self.norm_1 = norm_1
-        self.attn = QuantAttentionFused(hidden_size, self.n_heads, qkv_layer, o_proj, dev=dev, max_seq_len=max_seq_len, use_alibi=True).to(dev)
+        self.attn = QuantAttentionFused(
+            hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
+            dev=dev, max_seq_len=max_seq_len, use_alibi=True
+        ).to(dev)
         self.norm_2 = norm_2
         self.ffn = mpt_mlp.to(dev)
@@ -30,16 +34,22 @@ class MPTBlock(nn.Module):
         return out, None, past_key_value
 
 class FalconDecoderLayer(nn.Module):
-    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len, input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True):
+    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len,
+                 input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True):
         super().__init__()
         self.n_heads = n_heads
+        self.n_kv_heads = 8
         self.hidden_size = hidden_size
         self.new_decoder_arch = new_decoder_arch
-        attention_shapes = self._get_attention_shapes(n_heads, max_seq_len, self.hidden_size // n_heads, new_decoder_arch)
+
+        if new_decoder_arch:
+            attention_shapes = None
+        else:
+            attention_shapes = self._get_attention_shapes(n_heads, max_seq_len, self.hidden_size // n_heads)
 
         # TODO: Falcon has ALiBi implemented but which model uses it?
         self.attn = QuantAttentionFused(
-            hidden_size, self.n_heads, qkv_layer, o_proj,
+            hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
             dev=dev, max_seq_len=max_seq_len, use_alibi=False,
             attention_shapes=attention_shapes
         ).to(dev)
@@ -52,47 +62,26 @@ class FalconDecoderLayer(nn.Module):
         self.mlp = mlp
 
-    def _get_attention_shapes(self, n_heads, max_seq_len, head_dim, new_decoder_arch):
+    def _get_attention_shapes(self, n_heads, max_seq_len, head_dim):
         batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
 
-        if new_decoder_arch:
-            kv_heads = 8
-
-            self.attention_shapes = {
-                # following fastertransformer definition
-                "cache_v": (batch_size, n_heads+(kv_heads*2), max_seq_len, head_dim,),
-                # 8: pack 8 fp16 in FT, if fp32 then use 4
-                "cache_k": (batch_size, n_heads+(kv_heads*2), head_dim // 8, max_seq_len, 8,),
-                "xqkv_view": (-1, n_heads+(kv_heads*2), head_dim),
-                "xq_slice": lambda xqkv: xqkv[:, :, :, 0],
-                "xk_slice": lambda xqkv: xqkv[:, :, :, 1],
-                "xv_slice": lambda xqkv: xqkv[:, :, :, 2],
-                "xk_reshape": (1, head_dim // 8, 8),
-                "xq_view": (1, head_dim),
-                "xk_view": (1, head_dim),
-                "xv_view": (1, head_dim),
-                "single_xq_view": (n_heads, head_dim),
-                "single_xk_view": (1, 8, head_dim),
-                "single_xv_view": (1, 8, head_dim)
-            }
-        else:
-            self.attention_shapes = {
-                # following fastertransformer definition
-                "cache_v": (batch_size, 1, max_seq_len, head_dim,),
-                # 8: pack 8 fp16 in FT, if fp32 then use 4
-                "cache_k": (batch_size, 1, head_dim // 8, max_seq_len, 8,),
-                "xqkv_view": (n_heads+2, head_dim),
-                "xq_slice": lambda xqkv: xqkv[:, :, :-2],
-                "xk_slice": lambda xqkv: xqkv[:, :, [-2]],
-                "xv_slice": lambda xqkv: xqkv[:, :, [-1]],
-                "xk_reshape": (1, head_dim // 8, 8),
-                "xq_view": (n_heads, head_dim),
-                "xk_view": (1, head_dim),
-                "xv_view": (1, head_dim),
-                "single_xq_view": (n_heads, head_dim),
-                "single_xk_view": (1, head_dim),
-                "single_xv_view": (1, head_dim)
-            }
+        self.attention_shapes = {
+            # following fastertransformer definition
+            "cache_v": (batch_size, 1, max_seq_len, head_dim,),
+            # 8: pack 8 fp16 in FT, if fp32 then use 4
+            "cache_k": (batch_size, 1, head_dim // 8, max_seq_len, 8,),
+            "xqkv_view": (n_heads+2, head_dim),
+            "xq_slice": lambda xqkv: xqkv[:, :, :-2],
+            "xk_slice": lambda xqkv: xqkv[:, :, [-2]],
+            "xv_slice": lambda xqkv: xqkv[:, :, [-1]],
+            "xq_view": (n_heads, head_dim),
+            "xk_view": (1, head_dim),
+            "xv_view": (1, head_dim),
+            "xk_reshape": (1, head_dim // 8, 8),
+            "single_xq_view": (n_heads, head_dim),
+            "single_xk_view": (1, head_dim),
+            "single_xv_view": (1, head_dim)
+        }
 
         return self.attention_shapes
...
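The remaining non-new-decoder-arch path corresponds to Falcon-7B style multi-query attention: every query head attends against a single shared key head and a single shared value head, which sit in the last two slots of the fused projection. A toy sketch of that slicing (illustrative sizes, not Falcon's real ones):

import torch

batch, seqlen, n_heads, head_dim = 1, 3, 8, 4

# Fused projection: n_heads query heads plus one key head and one value head.
xqkv = torch.randn(batch, seqlen, (n_heads + 2) * head_dim)
xqkv = xqkv.view(batch, seqlen, n_heads + 2, head_dim)

xq = xqkv[:, :, :-2]   # (1, 3, 8, 4) all query heads
xk = xqkv[:, :, [-2]]  # (1, 3, 1, 4) the single shared key head
xv = xqkv[:, :, [-1]]  # (1, 3, 1, 4) the single shared value head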