Commit ac3e86df authored by Casper Hansen

Remove fusing attention, only blocks

parent 950851b3
 from .base import BaseAWQForCausalLM
-from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM, MptAttention
+from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM
 
 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"
@@ -8,7 +8,6 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
     @staticmethod
     def fuse_layers(model: MptForCausalLM, quant_config:dict):
         fuser = MptFuser(model)
-        fuser.fuse_attention()
         fuser.fuse_block()
 
     @staticmethod
@@ -69,34 +68,15 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
 from typing import List, Tuple
 from awq.utils.utils import set_module_name
 from awq.modules.fused.block import MptBlock
-from awq.modules.fused.attn import QuantAttentionFused
 
 class MptFuser:
     def __init__(self, model):
         self.model = model
 
-        self.attention_modules: List[Tuple[str, MptAttention]] = [
-            (name, module) for name, module in self.model.named_modules()
-            if isinstance(module, MptAttention)
-        ]
-
         self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [
            (name, module) for name, module in self.model.named_modules()
            if 'mptblock' in module.__class__.__name__.lower()
         ]
 
-    def fuse_attention(self):
-        for name, qkv_layer in self.attention_modules:
-            attn = QuantAttentionFused(
-                qkv_layer.hidden_size,
-                qkv_layer.n_heads,
-                qkv_layer,
-                qkv_layer.out_proj,
-                next(iter(qkv_layer.state_dict().values())).device,
-                self.model.config.max_new_tokens,
-                use_alibi=True
-            )
-            set_module_name(self.model, name, attn)
-
     def fuse_block(self):
         for name, module in self.mpt_blocks:
......
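For context, a minimal sketch of how fusing is driven after this change. It assumes the diffed file is importable as awq.models.mpt (the path is not visible in the diff) and that the model passed in is an already-quantized MptForCausalLM; this is an illustration, not part of the commit.

from transformers.models.mpt.modeling_mpt import MptForCausalLM
from awq.models.mpt import MptFuser  # import path assumed; MptFuser is defined in the diffed file

def fuse_mpt(model: MptForCausalLM) -> None:
    # After this commit only whole transformer blocks are fused: fuse_block()
    # swaps every MptBlock for the fused implementation, and the separate
    # fuse_attention() step no longer exists.
    fuser = MptFuser(model)
    fuser.fuse_block()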