Commit 4517b3f2 authored by Casper Hansen

Create Falcon block and model for fusing

parent e120c9b6
from .base import BaseAWQForCausalLM
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconForCausalLM, FalconAttention
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer as OldFalconDecoderLayer, FalconForCausalLM, FalconAttention
class FalconAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "FalconDecoderLayer"
@@ -7,13 +7,14 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
    @staticmethod
    def fuse_layers(model: FalconForCausalLM, quant_config: dict):
        fuser = FalconFuser(model)
        # fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: FalconForCausalLM):
        return model.transformer.h

    @staticmethod
    def get_act_for_scaling(module: FalconDecoderLayer):
    def get_act_for_scaling(module: OldFalconDecoderLayer):
        return dict(
            is_scalable=True,
            scale_name="mlp.act",
@@ -26,7 +27,7 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
        model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)

    @staticmethod
    def get_layers_for_scaling(module: FalconDecoderLayer, input_feat, module_kwargs):
    def get_layers_for_scaling(module: OldFalconDecoderLayer, input_feat, module_kwargs):
        layers = []

        # Falcon 7B (older architecture)
@@ -62,34 +63,46 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
        return layers
import torch
from torch.nn import LayerNorm
from typing import List, Tuple
from awq.utils.utils import set_module_name
from awq.modules.fused.attn import QuantAttentionFused
from awq.modules.fused.model import FalconModel
from awq.modules.fused.block import FalconDecoderLayer

class FalconFuser:
    def __init__(self, model):
    def __init__(self, model: FalconForCausalLM):
        self.model = model

        self.attention_modules: List[Tuple[str, FalconAttention]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, FalconAttention)
        ]
        self.layernorm_modules: List[Tuple[str, LayerNorm]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LayerNorm)
        ]
    def fuse_attention(self):
        for name, qkv_layer in self.attention_modules:
            attn = QuantAttentionFused(
                qkv_layer.hidden_size,
                qkv_layer.num_heads,
                qkv_layer,
                qkv_layer.dense,
                next(iter(qkv_layer.state_dict().values())).device,
                self.model.config.max_new_tokens
            )
            set_module_name(self.model, name, attn)
\ No newline at end of file

    def fuse_transformer(self):
        blocks = []

        module: OldFalconDecoderLayer
        for module in self.model.transformer.h:
            if module.config.num_attention_heads == 71:
                input_layernorm = module.input_layernorm
                ln_attn = None
                ln_mlp = None
                new_decoder_arch = False
            else:
                input_layernorm = None
                ln_attn = module.ln_attn
                ln_mlp = module.ln_mlp
                new_decoder_arch = True

            blocks.append(FalconDecoderLayer(
                hidden_size=module.config.hidden_size,
                n_heads=module.config.num_attention_heads,
                qkv_layer=module.self_attention.query_key_value,
                o_proj=module.self_attention.dense,
                mlp=module.mlp,
                dev=next(iter(module.state_dict().values())).device,
                max_seq_len=self.model.config.max_new_tokens,
                input_layernorm=input_layernorm,
                ln_attn=ln_attn,
                ln_mlp=ln_mlp,
                new_decoder_arch=new_decoder_arch
            ))

        self.model.transformer = FalconModel(
            self.model.config.vocab_size,
            blocks,
            self.model.transformer.word_embeddings,
            self.model.transformer.ln_f,
        )
\ No newline at end of file
import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused

class MPTBlock(nn.Module):
    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len):
        super().__init__()
@@ -28,3 +27,49 @@ class MPTBlock(nn.Module):
        h = hidden_states + attn_output
        out = h + self.ffn.forward(self.norm_2(h))
        return out, None, past_key_value
class FalconDecoderLayer(nn.Module):
    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len, input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True):
        super().__init__()
        self.n_heads = n_heads
        self.hidden_size = hidden_size
        # TODO: Falcon has ALiBi implemented but which model uses it?
        self.attn = QuantAttentionFused(hidden_size, self.n_heads, qkv_layer, o_proj, dev=dev, max_seq_len=max_seq_len, use_alibi=False).to(dev)
        self.new_decoder_arch = new_decoder_arch

        if new_decoder_arch:
            self.ln_attn = ln_attn  # before attention
            self.ln_mlp = ln_mlp  # before mlp
        else:
            self.input_layernorm = input_layernorm  # before attention

        self.mlp = mlp

    def forward(
        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
    ):
        if self.new_decoder_arch:
            layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            layernorm_out = self.input_layernorm(hidden_states)

        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=layernorm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True
        )

        h_attn = hidden_states + attn_output

        if self.new_decoder_arch:
            h_mlp = self.mlp.forward(mlp_layernorm_out)
        else:
            h_mlp = self.mlp.forward(layernorm_out)

        out = h_attn + h_mlp

        return out, None, past_key_value
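As a brief aside, here is a minimal, self-contained sketch (not part of this commit) of the parallel residual that the new-decoder-arch path of FalconDecoderLayer.forward computes above: attention and MLP each read a layernormed copy of the same hidden states, and both outputs are added onto the residual. Toy nn.Linear modules stand in for the quantized attention and fused MLP purely for illustration.

import torch
import torch.nn as nn

hidden_size = 8
ln_attn, ln_mlp = nn.LayerNorm(hidden_size), nn.LayerNorm(hidden_size)
attn, mlp = nn.Linear(hidden_size, hidden_size), nn.Linear(hidden_size, hidden_size)

x = torch.randn(1, 4, hidden_size)           # (batch, seq_len, hidden)
out = x + attn(ln_attn(x)) + mlp(ln_mlp(x))  # mirrors h_attn + h_mlp above
print(out.shape)                             # torch.Size([1, 4, 8])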
import torch
import torch.nn as nn
from awq.modules.fused.block import MPTBlock
from awq.modules.fused.block import MPTBlock, FalconDecoderLayer
from transformers.modeling_outputs import BaseModelOutputWithPast
class MPTModel(nn.Module):
@@ -30,3 +30,31 @@ class MPTModel(nn.Module):
        h = self.norm_f(h)
        return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
class FalconModel(nn.Module):
    def __init__(self, vocab_size, blocks, word_embeddings, ln_f):
        super().__init__()
        self.vocab_size = vocab_size
        self.word_embeddings = word_embeddings
        self.blocks: list[FalconDecoderLayer] = nn.ModuleList(blocks)
        self.ln_f = ln_f
        self.attn_uses_sequence_id = False
        self.prefix_lm = False

    @torch.inference_mode()
    def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
        _bsz, seqlen = input_ids.shape
        h = self.word_embeddings(input_ids)

        mask = None
        if seqlen > 1:
            mask = torch.full(
                (1, 1, seqlen, seqlen), float("-inf"), device=input_ids.device
            )
            mask = torch.triu(mask, diagonal=self.blocks[0].attn.start_pos + 1).type_as(h)

        for layer in self.blocks:
            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
        h = self.ln_f(h)

        return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
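For reference, a standalone sketch of the additive causal mask that FalconModel.forward builds above; start_pos here stands in for self.blocks[0].attn.start_pos (the number of tokens already held in the fused KV cache), and the helper name is just for illustration.

import torch

def build_causal_mask(seqlen: int, start_pos: int, device="cpu", dtype=torch.float16):
    # Single-token decode steps (seqlen == 1) skip the mask entirely, as in forward().
    if seqlen <= 1:
        return None
    mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=device)
    # torch.triu keeps -inf strictly above the shifted diagonal (blocked positions)
    # and zeroes everything on/below it (visible positions).
    return torch.triu(mask, diagonal=start_pos + 1).to(dtype)

# A fresh 4-token prompt with an empty cache yields a standard causal mask.
print(build_causal_mask(4, 0))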