Unverified Commit a5e8b048 authored by Casper's avatar Casper Committed by GitHub
Browse files

Merge pull request #60 from casper-hansen/kv_heads

Support kv_heads
parents bf76e108 a024e893
......@@ -7,7 +7,10 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
@staticmethod
def fuse_layers(model: FalconForCausalLM, quant_config:dict):
fuser = FalconFuser(model)
fuser.fuse_transformer()
# TODO: Implement correctly fused modules for Falcon 40B and Falcon 180B
if model.config.num_attention_heads == 71:
fuser.fuse_transformer()
@staticmethod
def get_model_layers(model: FalconForCausalLM):
......
......@@ -100,6 +100,7 @@ class LlamaFuser:
attn = QuantAttentionFused(
module.hidden_size,
module.num_heads,
module.num_key_value_heads,
qkv_layer,
module.o_proj,
next(iter(qkv_layer.state_dict().values())).device,
......
......@@ -61,34 +61,60 @@ def build_alibi_bias(
class QuantAttentionFused(nn.Module):
def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_seq_len,
def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
use_alibi=False, attention_shapes=None):
super().__init__()
self.hidden_size = hidden_size
self.n_local_heads = num_heads
self.head_dim = self.hidden_size // num_heads
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
self.head_dim = self.hidden_size // n_heads
self.qkv_proj = qkv_layer
self.o_proj = o_proj
self.start_pos = 0
self.use_alibi = use_alibi
self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
self.attention_shapes = attention_shapes if attention_shapes is not None else {
# following fastertransformer definition
"cache_v": (self.cache_batch_size, self.n_local_heads, max_seq_len, self.head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (self.cache_batch_size, self.n_local_heads, self.head_dim // 8, max_seq_len, 8,),
"xqkv_view": (-1, self.n_local_heads, self.head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, 0],
"xk_slice": lambda xqkv: xqkv[:, :, 1],
"xv_slice": lambda xqkv: xqkv[:, :, 2],
"xk_reshape": (self.n_local_heads, self.head_dim // 8, 8),
"xq_view": (self.n_local_heads, self.head_dim),
"xk_view": (self.n_local_heads, self.head_dim),
"xv_view": (self.n_local_heads, self.head_dim),
"single_xq_view": (self.n_local_heads, self.head_dim),
"single_xk_view": (self.n_local_heads, self.head_dim),
"single_xv_view": (self.n_local_heads, self.head_dim)
}
if attention_shapes is not None:
self.attention_shapes = attention_shapes
elif self.n_kv_heads == 0:
self.attention_shapes = {
# following fastertransformer definition
"cache_v": (self.cache_batch_size, self.n_heads, max_seq_len, self.head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (self.cache_batch_size, self.n_heads, self.head_dim // 8, max_seq_len, 8,),
"xqkv_view": (-1, self.n_heads, self.head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, 0],
"xk_slice": lambda xqkv: xqkv[:, :, 1],
"xv_slice": lambda xqkv: xqkv[:, :, 2],
"xq_view": (self.n_heads, self.head_dim),
"xk_view": (self.n_heads, self.head_dim),
"xv_view": (self.n_heads, self.head_dim),
"xk_reshape": (self.n_heads, self.head_dim // 8, 8),
"single_xq_view": (self.n_heads, self.head_dim),
"single_xk_view": (self.n_heads, self.head_dim),
"single_xv_view": (self.n_heads, self.head_dim)
}
else:
self.attention_shapes = {
# following fastertransformer definition
"cache_v": (self.cache_batch_size, self.n_kv_heads, max_seq_len, self.head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (self.cache_batch_size, self.n_kv_heads, self.head_dim // 8, max_seq_len, 8,),
"xqkv_view": (self.n_heads + self.n_kv_heads * 2, self.head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, 0 : self.n_heads],
"xk_slice": lambda xqkv: xqkv[:, :, self.n_heads : (self.n_heads + self.n_kv_heads)],
"xv_slice": lambda xqkv: xqkv[:, :, -self.n_kv_heads :],
"xq_view": (self.n_heads, self.head_dim),
"xk_view": (self.n_kv_heads, self.head_dim),
"xv_view": (self.n_kv_heads, self.head_dim),
"xk_reshape": (self.n_kv_heads, self.head_dim // 8, 8),
"single_xq_view": (self.n_heads, self.head_dim),
"single_xk_view": (self.n_kv_heads, self.head_dim),
"single_xv_view": (self.n_kv_heads, self.head_dim)
}
self.cache_v = (
torch.zeros(self.attention_shapes["cache_v"]).to(dev).half()
......@@ -99,14 +125,14 @@ class QuantAttentionFused(nn.Module):
)
if use_alibi:
alibi_slopes, alibi_bias = build_alibi_bias(self.n_local_heads, max_seq_len)
alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
self.alibi_slopes = alibi_slopes.float().to(dev)
self.alibi_bias = alibi_bias.float().to(dev)
self.rotary_dim = 0
self.is_neox = False
else:
self.freqs_cis = precompute_freqs_cis(
hidden_size // num_heads,
hidden_size // n_heads,
max_seq_len * 2,
).to(dev)
self.rotary_dim = self.head_dim
......@@ -153,6 +179,11 @@ class QuantAttentionFused(nn.Module):
keys = xk
values = xv
if self.n_kv_groups != 0:
keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
values = torch.repeat_interleave(values, dim=2, repeats=self.n_kv_groups)
past_key_value = (xk, xv) if use_cache else None
xq = xq.transpose(1, 2)
......
......@@ -6,9 +6,13 @@ class MPTBlock(nn.Module):
def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = 0
self.hidden_size = hidden_size
self.norm_1 = norm_1
self.attn = QuantAttentionFused(hidden_size, self.n_heads, qkv_layer, o_proj, dev=dev, max_seq_len=max_seq_len, use_alibi=True).to(dev)
self.attn = QuantAttentionFused(
hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
dev=dev, max_seq_len=max_seq_len, use_alibi=True
).to(dev)
self.norm_2 = norm_2
self.ffn = mpt_mlp.to(dev)
......@@ -30,16 +34,22 @@ class MPTBlock(nn.Module):
return out, None, past_key_value
class FalconDecoderLayer(nn.Module):
def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len, input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True):
def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len,
input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = 8
self.hidden_size = hidden_size
self.new_decoder_arch = new_decoder_arch
attention_shapes = self._get_attention_shapes(n_heads, max_seq_len, self.hidden_size // n_heads, new_decoder_arch)
if new_decoder_arch:
attention_shapes = None
else:
attention_shapes = self._get_attention_shapes(n_heads, max_seq_len, self.hidden_size // n_heads)
# TODO: Falcon has ALiBi implemented but which model uses it?
self.attn = QuantAttentionFused(
hidden_size, self.n_heads, qkv_layer, o_proj,
hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
dev=dev, max_seq_len=max_seq_len, use_alibi=False,
attention_shapes=attention_shapes
).to(dev)
......@@ -52,47 +62,26 @@ class FalconDecoderLayer(nn.Module):
self.mlp = mlp
def _get_attention_shapes(self, n_heads, max_seq_len, head_dim, new_decoder_arch):
def _get_attention_shapes(self, n_heads, max_seq_len, head_dim):
batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
if new_decoder_arch:
kv_heads = 8
self.attention_shapes = {
# following fastertransformer definition
"cache_v": (batch_size, n_heads+(kv_heads*2), max_seq_len, head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (batch_size, n_heads+(kv_heads*2), head_dim // 8, max_seq_len, 8,),
"xqkv_view": (-1, n_heads+(kv_heads*2), head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, :,0],
"xk_slice": lambda xqkv: xqkv[:, :, :,1],
"xv_slice": lambda xqkv: xqkv[:, :, :,2],
"xk_reshape": (1, head_dim // 8, 8),
"xq_view": (1, head_dim),
"xk_view": (1, head_dim),
"xv_view": (1, head_dim),
"single_xq_view": (n_heads, head_dim),
"single_xk_view": (1, 8, head_dim),
"single_xv_view": (1, 8, head_dim)
}
else:
self.attention_shapes = {
# following fastertransformer definition
"cache_v": (batch_size, 1, max_seq_len, head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (batch_size, 1, head_dim // 8, max_seq_len, 8,),
"xqkv_view": (n_heads+2, head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, :-2],
"xk_slice": lambda xqkv: xqkv[:, :, [-2]],
"xv_slice": lambda xqkv: xqkv[:, :, [-1]],
"xk_reshape": (1, head_dim // 8, 8),
"xq_view": (n_heads, head_dim),
"xk_view": (1, head_dim),
"xv_view": (1, head_dim),
"single_xq_view": (n_heads, head_dim),
"single_xk_view": (1, head_dim),
"single_xv_view": (1, head_dim)
}
self.attention_shapes = {
# following fastertransformer definition
"cache_v": (batch_size, 1, max_seq_len, head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (batch_size, 1, head_dim // 8, max_seq_len, 8,),
"xqkv_view": (n_heads+2, head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, :-2],
"xk_slice": lambda xqkv: xqkv[:, :, [-2]],
"xv_slice": lambda xqkv: xqkv[:, :, [-1]],
"xq_view": (n_heads, head_dim),
"xk_view": (1, head_dim),
"xv_view": (1, head_dim),
"xk_reshape": (1, head_dim // 8, 8),
"single_xq_view": (n_heads, head_dim),
"single_xk_view": (1, head_dim),
"single_xv_view": (1, head_dim)
}
return self.attention_shapes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment