Commit 54f02854 authored by Casper Hansen

Falcon fused layers

parent 73c5e2bf
from .base import BaseAWQForCausalLM
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconForCausalLM, FalconAttention

class FalconAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "FalconDecoderLayer"
    @staticmethod
    def fuse_layers(model: FalconForCausalLM, quant_config: dict):
        fuser = FalconFuser(model)
        # fuser.fuse_attention()
        # fuser.fuse_layernorm()
    @staticmethod
    def get_model_layers(model: FalconForCausalLM):
        return model.transformer.h
@@ -56,4 +62,50 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
                kwargs=module_kwargs,
            ))

        return layers
\ No newline at end of file
import torch
import xformers
from torch.nn import LayerNorm
from typing import List, Tuple
from awq.utils.utils import set_module_name
from xformers.triton.layer_norm import FusedLayerNorm
from awq.modules.fused.attn import QuantAttentionFused


class FalconFuser:
    def __init__(self, model):
        self.model = model

        # Collect every FalconAttention and every LayerNorm module by dotted
        # name so they can later be swapped out for fused variants.
        self.attention_modules: List[Tuple[str, FalconAttention]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, FalconAttention)
        ]
        self.layernorm_modules: List[Tuple[str, LayerNorm]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LayerNorm)
        ]

    def fuse_attention(self):
        # Replace each FalconAttention module with a fused quantized attention module.
        for name, qkv_layer in self.attention_modules:
            attn = QuantAttentionFused(
                qkv_layer.hidden_size,
                qkv_layer.num_heads,
                qkv_layer,
                qkv_layer.dense,
                next(iter(qkv_layer.state_dict().values())).device,
                self.model.config.max_new_tokens
            )
            set_module_name(self.model, name, attn)

    def fuse_layernorm(self):
        xformers.triton.k_layer_norm._triton_layernorm_fp16_enabled = True

        for name, module in self.layernorm_modules:
            # Build an xformers Triton FusedLayerNorm with the same shape and epsilon.
            norm = FusedLayerNorm(module.weight.shape, eps=module.eps).to(module.weight.device)

            # Copy weights and bias from the original LayerNorm.
            with torch.no_grad():
                norm.weight = module.weight
                norm.bias = module.bias

            set_module_name(self.model, name, norm)
\ No newline at end of file
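
For reference, below is a minimal sketch of the module-replacement helper this fuser relies on. It is an assumption about what awq.utils.utils.set_module_name does (resolve a dotted submodule path and swap the leaf module on its parent), not the implementation shipped in this commit:

import torch.nn as nn

def set_module_name_sketch(model: nn.Module, name: str, new_module: nn.Module):
    # Hypothetical stand-in for awq.utils.utils.set_module_name.
    # Split e.g. "transformer.h.0.self_attention" into parent path and leaf name.
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    # Swap the leaf module in place, e.g. FalconAttention -> QuantAttentionFused.
    setattr(parent, child_name, new_module)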