Commit 4517b3f2 authored by Casper Hansen

Create Falcon block and model for fusing

parent e120c9b6
awq/models/falcon.py

 from .base import BaseAWQForCausalLM
-from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconForCausalLM, FalconAttention
+from transformers.models.falcon.modeling_falcon import FalconDecoderLayer as OldFalconDecoderLayer, FalconForCausalLM, FalconAttention
 
 class FalconAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "FalconDecoderLayer"
@@ -7,13 +7,14 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
     @staticmethod
     def fuse_layers(model: FalconForCausalLM, quant_config:dict):
         fuser = FalconFuser(model)
+        # fuser.fuse_transformer()
 
     @staticmethod
     def get_model_layers(model: FalconForCausalLM):
         return model.transformer.h
 
     @staticmethod
-    def get_act_for_scaling(module: FalconDecoderLayer):
+    def get_act_for_scaling(module: OldFalconDecoderLayer):
         return dict(
             is_scalable=True,
             scale_name="mlp.act",
@@ -26,7 +27,7 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
         model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
 
     @staticmethod
-    def get_layers_for_scaling(module: FalconDecoderLayer, input_feat, module_kwargs):
+    def get_layers_for_scaling(module: OldFalconDecoderLayer, input_feat, module_kwargs):
         layers = []
 
         # Falcon 7B (older architecture)
@@ -62,34 +63,46 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
         return layers
 
 
-import torch
-from torch.nn import LayerNorm
-from typing import List, Tuple
-from awq.utils.utils import set_module_name
-from awq.modules.fused.attn import QuantAttentionFused
+from awq.modules.fused.model import FalconModel
+from awq.modules.fused.block import FalconDecoderLayer
 
 class FalconFuser:
-    def __init__(self, model):
+    def __init__(self, model: FalconForCausalLM):
         self.model = model
 
-        self.attention_modules: List[Tuple[str, FalconAttention]] = [
-            (name, module) for name, module in self.model.named_modules()
-            if isinstance(module, FalconAttention)
-        ]
-
-        self.layernorm_modules: List[Tuple[str, LayerNorm]] = [
-            (name, module) for name, module in self.model.named_modules()
-            if isinstance(module, LayerNorm)
-        ]
-
-    def fuse_attention(self):
-        for name, qkv_layer in self.attention_modules:
-            attn = QuantAttentionFused(
-                qkv_layer.hidden_size,
-                qkv_layer.num_heads,
-                qkv_layer,
-                qkv_layer.dense,
-                next(iter(qkv_layer.state_dict().values())).device,
-                self.model.config.max_new_tokens
-            )
-            set_module_name(self.model, name, attn)
\ No newline at end of file
+    def fuse_transformer(self):
+        blocks = []
+
+        module: OldFalconDecoderLayer
+        for module in self.model.transformer.h:
+            if module.config.num_attention_heads == 71:
+                input_layernorm = module.input_layernorm
+                ln_attn = None
+                ln_mlp = None
+                new_decoder_arch = False
+            else:
+                input_layernorm = None
+                ln_attn = module.ln_attn
+                ln_mlp = module.ln_mlp
+                new_decoder_arch = True
+
+            blocks.append(FalconDecoderLayer(
+                hidden_size=module.config.hidden_size,
+                n_heads=module.config.num_attention_heads,
+                qkv_layer=module.self_attention.query_key_value,
+                o_proj=module.self_attention.dense,
+                mlp=module.mlp,
+                dev=next(iter(module.state_dict().values())).device,
+                max_seq_len=self.model.config.max_new_tokens,
+                input_layernorm=input_layernorm,
+                ln_attn=ln_attn,
+                ln_mlp=ln_mlp,
+                new_decoder_arch=new_decoder_arch
+            ))
+
+        self.model.transformer = FalconModel(
+            self.model.config.vocab_size,
+            blocks,
+            self.model.transformer.word_embeddings,
+            self.model.transformer.ln_f,
        )
\ No newline at end of file
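
For orientation, a minimal hand-written sketch (not part of the commit) of how the new fuser would be driven: fuse_layers() above constructs the FalconFuser but leaves fuser.fuse_transformer() commented out, and fuse_transformer() sizes the fused KV cache from model.config.max_new_tokens. The loader call, the .model attribute, the local checkpoint path, and the FalconFuser import path below are assumptions.

# A minimal sketch of driving the new fuser by hand. Assumptions (not from the
# commit): the AutoAWQ loader call and its attributes, the checkpoint path, and
# the import location of FalconFuser.
from awq import AutoAWQForCausalLM
from awq.models.falcon import FalconFuser  # assumed module path for the file above

quant_path = "falcon-7b-awq"  # hypothetical local AWQ-quantized Falcon checkpoint

awq_model = AutoAWQForCausalLM.from_quantized(quant_path)  # assumed loader entry point
hf_model = awq_model.model  # assumed: the underlying FalconForCausalLM

# fuse_transformer() reads model.config.max_new_tokens as max_seq_len for the
# fused blocks, so set it before fusing.
hf_model.config.max_new_tokens = 512

# fuse_layers() builds the fuser but keeps this call commented out; invoke it directly.
FalconFuser(hf_model).fuse_transformer()  # swaps hf_model.transformer for the fused FalconModel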
awq/modules/fused/block.py

 import torch.nn as nn
 from awq.modules.fused.attn import QuantAttentionFused
 
 class MPTBlock(nn.Module):
     def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len):
         super().__init__()
@@ -28,3 +27,49 @@ class MPTBlock(nn.Module):
         h = hidden_states + attn_output
         out = h + self.ffn.forward(self.norm_2(h))
         return out, None, past_key_value
+
+class FalconDecoderLayer(nn.Module):
+    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len, input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True):
+        super().__init__()
+        self.n_heads = n_heads
+        self.hidden_size = hidden_size
+        # TODO: Falcon has ALiBi implemented but which model uses it?
+        self.attn = QuantAttentionFused(hidden_size, self.n_heads, qkv_layer, o_proj, dev=dev, max_seq_len=max_seq_len, use_alibi=False).to(dev)
+        self.new_decoder_arch = new_decoder_arch
+
+        if new_decoder_arch:
+            self.ln_attn = ln_attn # before attention
+            self.ln_mlp = ln_mlp # before mlp
+        else:
+            self.input_layernorm = input_layernorm # before attention
+
+        self.mlp = mlp
+
+    def forward(
+        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
+    ):
+        if self.new_decoder_arch:
+            layernorm_out = self.ln_attn(hidden_states)
+            mlp_layernorm_out = self.ln_mlp(hidden_states)
+        else:
+            layernorm_out = self.input_layernorm(hidden_states)
+
+        attn_output, _, past_key_value = self.attn.forward(
+            hidden_states=layernorm_out,
+            past_key_value=past_key_value,
+            attention_mask=attention_mask,
+            position_ids=None,
+            output_attentions=False,
+            use_cache=True
+        )
+
+        h_attn = hidden_states + attn_output
+
+        if self.new_decoder_arch:
+            h_mlp = self.mlp.forward(mlp_layernorm_out)
+        else:
+            h_mlp = self.mlp.forward(layernorm_out)
+
+        out = h_attn + h_mlp
+
+        return out, None, past_key_value
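
The forward pass of the new FalconDecoderLayer reproduces Falcon's parallel-residual block: attention and MLP each read a layernormed copy of the same input, and both outputs are added to the residual together rather than applied sequentially. The only difference selected by new_decoder_arch is whether one shared input_layernorm feeds both branches (the 71-head Falcon-7B layout) or separate ln_attn/ln_mlp norms are used. A small standalone sketch of that dataflow with plain nn.Linear stand-ins (no quantized attention, purely illustrative):

import torch
import torch.nn as nn

hidden = 64
x = torch.randn(1, 8, hidden)                     # (batch, seq, hidden) toy activations
attn = nn.Linear(hidden, hidden)                  # stand-in for the fused attention
mlp = nn.Linear(hidden, hidden)                   # stand-in for the Falcon MLP

# new_decoder_arch=False (Falcon-7B style): one layernorm feeds both branches
input_layernorm = nn.LayerNorm(hidden)
ln_out = input_layernorm(x)
out_7b = x + attn(ln_out) + mlp(ln_out)           # h_attn + h_mlp with a shared layernorm_out

# new_decoder_arch=True: separate layernorms for the attention and MLP branches
ln_attn, ln_mlp = nn.LayerNorm(hidden), nn.LayerNorm(hidden)
out_40b = x + attn(ln_attn(x)) + mlp(ln_mlp(x))   # same parallel-residual sum, two norms

print(out_7b.shape, out_40b.shape)                # both torch.Size([1, 8, 64])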
awq/modules/fused/model.py

 import torch
 import torch.nn as nn
-from awq.modules.fused.block import MPTBlock
+from awq.modules.fused.block import MPTBlock, FalconDecoderLayer
 from transformers.modeling_outputs import BaseModelOutputWithPast
 
 class MPTModel(nn.Module):
@@ -30,3 +30,31 @@ class MPTModel(nn.Module):
         h = self.norm_f(h)
 
         return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
+
+class FalconModel(nn.Module):
+    def __init__(self, vocab_size, blocks, word_embeddings, ln_f):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.word_embeddings = word_embeddings
+        self.blocks: list[FalconDecoderLayer] = nn.ModuleList(blocks)
+        self.ln_f = ln_f
+        self.attn_uses_sequence_id = False
+        self.prefix_lm = False
+
+    @torch.inference_mode()
+    def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
+        _bsz, seqlen = input_ids.shape
+        h = self.word_embeddings(input_ids)
+
+        mask = None
+        if seqlen > 1:
+            mask = torch.full(
+                (1, 1, seqlen, seqlen), float("-inf"), device=input_ids.device
+            )
+            mask = torch.triu(mask, diagonal=self.blocks[0].attn.start_pos + 1).type_as(h)
+
+        for layer in self.blocks:
+            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
+        h = self.ln_f(h)
+
+        return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
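
FalconModel.forward builds the standard additive causal mask for the prefill pass (seqlen > 1): a (1, 1, seqlen, seqlen) block of -inf, zeroed on and below a diagonal shifted by the fused attention's start_pos so that already-cached positions stay visible. A tiny standalone sketch of what that produces for seqlen=4 with an empty cache (start_pos=0 is an assumption about the attention module's cache counter):

import torch

seqlen, start_pos = 4, 0                         # start_pos tracks the fused attention's KV cache
mask = torch.full((1, 1, seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=start_pos + 1)  # zeros on/below the diagonal, -inf strictly above
print(mask[0, 0])
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])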