"vscode:/vscode.git/clone" did not exist on "0f346a3296486deb79c63f778b9fc4d9107e4a23"
Commit ac770875 authored by Casper Hansen
Browse files

Create custom attention shape for Falcon 7B

parent 7f8f9f16
......@@ -33,10 +33,20 @@ class FalconDecoderLayer(nn.Module):
super().__init__()
self.n_heads = n_heads
self.hidden_size = hidden_size
# TODO: Falcon has ALiBi implemented but which model uses it?
self.attn = QuantAttentionFused(hidden_size, self.n_heads, qkv_layer, o_proj, dev=dev, max_seq_len=max_seq_len, use_alibi=False).to(dev)
self.new_decoder_arch = new_decoder_arch
if new_decoder_arch:
attention_shapes = None
else:
attention_shapes = self._get_attention_shapes(1, n_heads, max_seq_len, self.hidden_size // n_heads)
# TODO: Falcon has ALiBi implemented but which model uses it?
self.attn = QuantAttentionFused(
hidden_size, self.n_heads, qkv_layer, o_proj,
dev=dev, max_seq_len=max_seq_len, use_alibi=False,
attention_shapes=attention_shapes
).to(dev)
if new_decoder_arch:
self.ln_attn = ln_attn # before attention
self.ln_mlp = ln_mlp # before mlp
......@@ -44,6 +54,26 @@ class FalconDecoderLayer(nn.Module):
self.input_layernorm = input_layernorm # before attention
self.mlp = mlp
def _get_attention_shapes(self, batch_size, n_heads, max_seq_len, head_dim, pack_size=8):
    """Build the custom attention-shape spec passed to the fused attention.

    The layout has ``n_heads`` query rows followed by exactly one key row
    and one value row in the fused QKV projection (``n_heads + 2`` rows of
    ``head_dim``), and a KV cache with a single head.
    NOTE(review): how each entry is consumed is defined by
    ``QuantAttentionFused``, which is not visible here; confirm there.

    Args:
        batch_size: Leading dimension of both cache shapes.
        n_heads: Number of query rows in the fused QKV view.
        max_seq_len: Sequence-length dimension of both cache shapes.
        head_dim: Size of one head; must be a multiple of ``pack_size``.
        pack_size: Number of elements packed together in the innermost
            key-cache dimension. Per the original note: 8 for fp16 in
            FasterTransformer, 4 if fp32. Defaults to 8.

    Returns:
        dict: Shape tuples and slicing callables keyed by name. The same
        dict is also stored on ``self.attention_shapes``.

    Raises:
        ValueError: If ``pack_size`` is not positive or does not evenly
            divide ``head_dim``. Floor division would otherwise silently
            drop elements from the key-cache shape.
    """
    if pack_size <= 0 or head_dim % pack_size != 0:
        raise ValueError(
            f"head_dim ({head_dim}) must be a positive multiple of "
            f"pack_size ({pack_size})"
        )

    packed_dim = head_dim // pack_size
    # The single key row and single value row share this view shape.
    kv_view = (1, head_dim)

    self.attention_shapes = {
        # following fastertransformer definition
        "cache_v": (batch_size, 1, max_seq_len, head_dim,),
        # innermost dim packs `pack_size` elements (8 fp16 in FT, 4 if fp32)
        "cache_k": (batch_size, 1, packed_dim, max_seq_len, pack_size,),
        "xqkv_view": (n_heads + 2, head_dim),
        # Everything except the last two rows is the query block.
        "xq_slice": lambda xqkv: xqkv[:, :, :-2],
        # List indices (rather than plain ints) are kept from the original.
        "xk_slice": lambda xqkv: xqkv[:, :, [-2]],
        "xv_slice": lambda xqkv: xqkv[:, :, [-1]],
        "xk_reshape": (1, packed_dim, pack_size),
        "xk_view": kv_view,
        "xv_view": kv_view,
        "single_xq_view": (n_heads, head_dim),
        "single_xk_view": kv_view,
        "single_xv_view": kv_view,
    }
    return self.attention_shapes
def forward(
self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment