Commit ac3e86df authored by Casper Hansen

Remove fusing attention, only blocks

parent 950851b3
 from .base import BaseAWQForCausalLM
-from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM, MptAttention
+from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM

 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"
@@ -8,7 +8,6 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
     @staticmethod
     def fuse_layers(model: MptForCausalLM, quant_config:dict):
         fuser = MptFuser(model)
-        fuser.fuse_attention()
         fuser.fuse_block()

     @staticmethod
@@ -69,34 +68,15 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
 from typing import List, Tuple
 from awq.utils.utils import set_module_name
 from awq.modules.fused.block import MptBlock
-from awq.modules.fused.attn import QuantAttentionFused

 class MptFuser:
     def __init__(self, model):
         self.model = model

-        self.attention_modules: List[Tuple[str, MptAttention]] = [
-            (name, module) for name, module in self.model.named_modules()
-            if isinstance(module, MptAttention)
-        ]
         self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [
             (name, module) for name, module in self.model.named_modules()
             if 'mptblock' in module.__class__.__name__.lower()
         ]

-    def fuse_attention(self):
-        for name, qkv_layer in self.attention_modules:
-            attn = QuantAttentionFused(
-                qkv_layer.hidden_size,
-                qkv_layer.n_heads,
-                qkv_layer,
-                qkv_layer.out_proj,
-                next(iter(qkv_layer.state_dict().values())).device,
-                self.model.config.max_new_tokens,
-                use_alibi=True
-            )
-            set_module_name(self.model, name, attn)
-
     def fuse_block(self):
         for name, module in self.mpt_blocks:
...
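For context on the pattern these fusers rely on: a submodule is replaced in place by walking its dotted name and swapping the attribute. Below is a minimal, self-contained sketch of that idea; set_module_by_name and ToyModel are hypothetical stand-ins for AutoAWQ's set_module_name helper and the MPT model, not the library's actual code.

import torch.nn as nn

def set_module_by_name(model: nn.Module, name: str, new_module: nn.Module) -> None:
    # Hypothetical helper mirroring what awq.utils.utils.set_module_name is
    # assumed to do: walk the dotted module path and swap in the replacement.
    parent = model
    *parents, leaf = name.split(".")
    for part in parents:
        parent = getattr(parent, part)
    setattr(parent, leaf, new_module)

# Toy model standing in for MptForCausalLM; the loop below follows the same
# shape as MptFuser: collect (name, module) pairs, then replace each in place.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))

model = ToyModel()
targets = [(name, m) for name, m in model.named_modules() if isinstance(m, nn.Linear)]
for name, _ in targets:
    set_module_by_name(model, name, nn.Identity())  # stand-in for a fused block
print(model)

The removed fuse_attention used exactly this shape (build a QuantAttentionFused module, then set_module_name); after this commit only the block-level swap in fuse_block remains.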