"...text-generation-inference.git" did not exist on "eefea5ee3184179b2f440238e403d26e34a17491"
Commit ac3e86df authored by Casper Hansen

Remove fusing attention, only blocks

parent 950851b3
 from .base import BaseAWQForCausalLM
-from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM, MptAttention
+from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM

 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"
@@ -8,7 +8,6 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
     @staticmethod
     def fuse_layers(model: MptForCausalLM, quant_config:dict):
         fuser = MptFuser(model)
-        fuser.fuse_attention()
         fuser.fuse_block()

     @staticmethod
@@ -69,34 +68,15 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
 from typing import List, Tuple
 from awq.utils.utils import set_module_name
 from awq.modules.fused.block import MptBlock
-from awq.modules.fused.attn import QuantAttentionFused

 class MptFuser:
     def __init__(self, model):
         self.model = model

-        self.attention_modules: List[Tuple[str, MptAttention]] = [
-            (name, module) for name, module in self.model.named_modules()
-            if isinstance(module, MptAttention)
-        ]
         self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [
             (name, module) for name, module in self.model.named_modules()
             if 'mptblock' in module.__class__.__name__.lower()
         ]

-    def fuse_attention(self):
-        for name, qkv_layer in self.attention_modules:
-            attn = QuantAttentionFused(
-                qkv_layer.hidden_size,
-                qkv_layer.n_heads,
-                qkv_layer,
-                qkv_layer.out_proj,
-                next(iter(qkv_layer.state_dict().values())).device,
-                self.model.config.max_new_tokens,
-                use_alibi=True
-            )
-            set_module_name(self.model, name, attn)
-
     def fuse_block(self):
         for name, module in self.mpt_blocks:
 ...
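For context, both the removed fuse_attention path and the remaining fuse_block path follow the same pattern: scan named_modules() for target layers, build a fused replacement, and splice it in by its dotted name via set_module_name. The snippet below is a minimal, self-contained sketch of that replacement pattern only; it is not the AutoAWQ implementation, and replace_named_module and build_fused_block are illustrative names, not library APIs.

# Sketch of the submodule-replacement pattern used by the fuser
# (illustrative only; not the set_module_name from awq.utils.utils).
import torch.nn as nn

def replace_named_module(model: nn.Module, name: str, new_module: nn.Module) -> None:
    # Swap the submodule at dotted path `name` with `new_module` by
    # looking up its parent and reassigning the child attribute.
    if "." in name:
        parent_name, child_name = name.rsplit(".", 1)
        parent = model.get_submodule(parent_name)
    else:
        parent, child_name = model, name
    setattr(parent, child_name, new_module)

# Usage, mirroring how MptFuser.__init__ finds blocks by class name:
# for name, module in model.named_modules():
#     if "mptblock" in module.__class__.__name__.lower():
#         replace_named_module(model, name, build_fused_block(module))  # hypothetical builder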