"...llama.cpp/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "544b6739dde2a6b156b1673c72d94949c1940be7"
Commit 54f02854 authored by Casper Hansen

Falcon fused layers

parent 73c5e2bf
from .base import BaseAWQForCausalLM
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconForCausalLM, FalconAttention

class FalconAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "FalconDecoderLayer"
    @staticmethod
    def fuse_layers(model: FalconForCausalLM, quant_config: dict):
        fuser = FalconFuser(model)
        # Attention and layernorm fusion are wired up below but left disabled for now.
        # fuser.fuse_attention()
        # fuser.fuse_layernorm()
    @staticmethod
    def get_model_layers(model: FalconForCausalLM):
        return model.transformer.h
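For context: fuse_layers is the hook the AWQ loader calls once the quantized weights are in place. A minimal sketch of exercising it, assuming AutoAWQ's from_quantized entry point exposes a fuse_layers flag for Falcon the same way it does for the other supported models (the checkpoint path below is hypothetical):

from awq import AutoAWQForCausalLM

# Hypothetical AWQ-quantized Falcon checkpoint; fuse_layers=True asks the loader
# to run FalconAWQForCausalLM.fuse_layers() on the loaded model.
model = AutoAWQForCausalLM.from_quantized("path/to/falcon-7b-awq", fuse_layers=True)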
@@ -56,4 +62,50 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
            kwargs=module_kwargs,
        ))
        return layers
import torch
import xformers
from torch.nn import LayerNorm
from typing import List, Tuple
from awq.utils.utils import set_module_name
from xformers.triton.layer_norm import FusedLayerNorm
from awq.modules.fused.attn import QuantAttentionFused
class FalconFuser:
    def __init__(self, model):
        self.model = model

        # Collect every Falcon attention block and every LayerNorm in the model
        # so the fuse_* passes below can swap them for fused equivalents.
        self.attention_modules: List[Tuple[str, FalconAttention]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, FalconAttention)
        ]
        self.layernorm_modules: List[Tuple[str, LayerNorm]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LayerNorm)
        ]
    def fuse_attention(self):
        for name, module in self.attention_modules:
            # Build a fused attention module from the quantized fused QKV projection
            # (FalconAttention.query_key_value) and output projection (dense),
            # then swap it in place of the original attention block.
            attn = QuantAttentionFused(
                module.hidden_size,
                module.num_heads,
                module.query_key_value,
                module.dense,
                next(iter(module.state_dict().values())).device,
                self.model.config.max_new_tokens
            )
            set_module_name(self.model, name, attn)
    def fuse_layernorm(self):
        # Allow the Triton layernorm kernel to run in fp16.
        xformers.triton.k_layer_norm._triton_layernorm_fp16_enabled = True

        for name, module in self.layernorm_modules:
            # Replace each torch.nn.LayerNorm with xformers' Triton FusedLayerNorm,
            # copying the learned weight and bias over.
            norm = FusedLayerNorm(module.weight.shape, eps=module.eps).to(module.weight.device)
            with torch.no_grad():
                norm.weight = module.weight
                norm.bias = module.bias

            set_module_name(self.model, name, norm)
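set_module_name comes from awq.utils.utils and is what actually rewires the model in both fuse_* passes; it is not shown in this diff. A rough sketch of the behavior assumed here, resolving the dotted module path and rebinding the last attribute on its parent (not the library's actual implementation):

import torch.nn as nn

def set_module_name_sketch(model: nn.Module, name: str, new_module: nn.Module):
    # Split a dotted path like "transformer.h.0.self_attention" into parent path
    # and leaf name, look up the parent module, and rebind the leaf to the replacement.
    if "." in name:
        parent_name, child_name = name.rsplit(".", 1)
        parent = model.get_submodule(parent_name)
    else:
        parent, child_name = model, name
    setattr(parent, child_name, new_module)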