Commit d973a0e0 authored by Casper Hansen

Use default attention shapes

parent 2d593b84
@@ -41,7 +41,11 @@ class FalconDecoderLayer(nn.Module):
         self.n_kv_heads = 8
         self.hidden_size = hidden_size
         self.new_decoder_arch = new_decoder_arch
-        attention_shapes = self._get_attention_shapes(n_heads, max_seq_len, self.hidden_size // n_heads, new_decoder_arch)
+        if new_decoder_arch:
+            attention_shapes = None
+        else:
+            attention_shapes = self._get_attention_shapes(n_heads, max_seq_len, self.hidden_size // n_heads)

         # TODO: Falcon has ALiBi implemented but which model uses it?
         self.attn = QuantAttentionFused(
@@ -58,30 +62,9 @@ class FalconDecoderLayer(nn.Module):
         self.mlp = mlp

-    def _get_attention_shapes(self, n_heads, max_seq_len, head_dim, new_decoder_arch):
+    def _get_attention_shapes(self, n_heads, max_seq_len, head_dim):
         batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
-        if new_decoder_arch:
-            kv_heads = 8
-            self.attention_shapes = {
-                # following fastertransformer definition
-                "cache_v": (batch_size, n_heads+(kv_heads*2), max_seq_len, head_dim,),
-                # 8: pack 8 fp16 in FT, if fp32 then use 4
-                "cache_k": (batch_size, n_heads+(kv_heads*2), head_dim // 8, max_seq_len, 8,),
-                "xqkv_view": (-1, n_heads+(kv_heads*2), head_dim),
-                "xq_slice": lambda xqkv: xqkv[:, :, :,0],
-                "xk_slice": lambda xqkv: xqkv[:, :, :,1],
-                "xv_slice": lambda xqkv: xqkv[:, :, :,2],
-                "xq_view": (1, head_dim),
-                "xk_view": (1, head_dim),
-                "xv_view": (1, head_dim),
-                "xk_reshape": (1, head_dim // 8, 8),
-                "single_xq_view": (n_heads, head_dim),
-                "single_xk_view": (1, 8, head_dim),
-                "single_xv_view": (1, 8, head_dim)
-            }
-        else:
-            self.attention_shapes = {
-                # following fastertransformer definition
-                "cache_v": (batch_size, 1, max_seq_len, head_dim,),
+        self.attention_shapes = {
+            # following fastertransformer definition
+            "cache_v": (batch_size, 1, max_seq_len, head_dim,),
......
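Note on the "8: pack 8 fp16 in FT" comment in the removed branch: it refers to the FasterTransformer KV-cache layout shared by both branches (and by the kept default). Values are stored as [batch, heads, seq, head_dim], while keys pack 8 fp16 values along the last axis as [batch, heads, head_dim // 8, seq, 8]. Below is a minimal sketch of what that layout implies for a single head; the view/permute used to fill the packed key cache is an illustrative assumption about how such a kernel expects its keys, not code from this repository.

import torch

# Toy sizes for illustration only; head_dim must be divisible by 8 for the
# packed fp16 key-cache layout (the removed comment notes 4 would be used for fp32).
batch_size, n_kv_heads, max_seq_len, head_dim = 1, 1, 16, 32

# Value cache: plain [batch, heads, seq, head_dim] layout, as in "cache_v" above.
cache_v = torch.zeros(batch_size, n_kv_heads, max_seq_len, head_dim, dtype=torch.float16)

# Key cache: FasterTransformer layout packing 8 fp16 values along the last axis,
# [batch, heads, head_dim // 8, seq, 8], as in "cache_k" above.
cache_k = torch.zeros(batch_size, n_kv_heads, head_dim // 8, max_seq_len, 8, dtype=torch.float16)

# Both layouts hold the same number of elements per head and per position.
assert cache_v.numel() == cache_k.numel()

# One possible way (an assumption, not the repository's kernel code) to map a
# [seq, head_dim] key tensor for a single head into the packed layout:
keys = torch.randn(max_seq_len, head_dim).half()
packed = keys.view(max_seq_len, head_dim // 8, 8).permute(1, 0, 2)
cache_k[0, 0] = packed

# Permuting back and flattening recovers the original [seq, head_dim] keys.
assert torch.equal(cache_k[0, 0].permute(1, 0, 2).reshape(max_seq_len, head_dim), keys)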