Commit ded3ea71 authored by Casper Hansen

Refactor Llama Quant MLP

parent 620966e8
 from .base import BaseAWQForCausalLM
-from awq.modules import make_fused_mlp
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
 
 class LlamaAWQForCausalLM(BaseAWQForCausalLM):
@@ -11,7 +10,7 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
         fuser = LlamaFuser(awq_model)
         fuser.fuse_attention()
         fuser.fuse_rmsnorm()
-        make_fused_mlp(awq_model)#fuser.fuse_mlp()
+        fuser.fuse_mlp()
 
     @staticmethod
     def get_model_layers(model: LlamaForCausalLM):
@@ -70,9 +69,10 @@ import torch
 from typing import List, Tuple
 from awq.quantize.qmodule import WQLinear
 from awq.utils.utils import set_module_name
+from awq.modules.fused_mlp import QuantLlamaMLP
 from awq.modules.fused_norm import FTLlamaRMSNorm
 from awq.modules.fused_attn import QuantLlamaAttention
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
 
 class LlamaFuser:
     def __init__(self, awq_model: BaseAWQForCausalLM):
@@ -89,6 +89,11 @@ class LlamaFuser:
             if isinstance(module, LlamaRMSNorm)
         ]
 
+        self.mlp_modules: List[Tuple[str, LlamaMLP]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, LlamaMLP)
+        ]
+
     def fuse_attention(self):
         for name, module in self.attention_modules:
             qkv_layer: WQLinear = self._fuse_qkv(module)
@@ -131,4 +136,6 @@ class LlamaFuser:
             set_module_name(self.model, name, norm)
 
     def fuse_mlp(self):
-        pass
+        for name, module in self.mlp_modules:
+            mlp = QuantLlamaMLP(module.gate_proj, module.down_proj, module.up_proj)
+            set_module_name(self.model, name, mlp)
\ No newline at end of file
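
The new fuse_mlp() collects every LlamaMLP found by named_modules() and swaps it for a QuantLlamaMLP through set_module_name, whose implementation is not part of this diff. As a rough illustration only (the real awq.utils.utils.set_module_name may differ), such a helper can be sketched as:

    import torch.nn as nn

    def set_module_name_sketch(model: nn.Module, name: str, new_module: nn.Module):
        # Hypothetical stand-in for awq.utils.utils.set_module_name, for
        # illustration only: resolve the dotted path returned by
        # named_modules() and rebind the leaf attribute on its parent.
        if "." in name:
            parent_name, child_name = name.rsplit(".", 1)
            parent = model.get_submodule(parent_name)
        else:
            parent, child_name = model, name
        setattr(parent, child_name, new_module)

Whatever the exact helper does, the fused module it installs is expected to reproduce the standard LLaMA feed-forward, down_proj(silu(gate_proj(x)) * up_proj(x)), up to quantization error.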
@@ -75,9 +75,7 @@ def make_fused_mlp(m, parent_name=''):
     """
     Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
     """
-    if isinstance(m, LlamaMLP):
-        return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj)
-    elif "mptmlp" in str(m.__class__).lower():
+    if "mptmlp" in str(m.__class__).lower():
         return QuantMPTMLP(m.up_proj, m.act, m.down_proj)
 
     for name, child in m.named_children():
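
With the Llama branch removed, make_fused_mlp is left handling only the MPT case; Llama MLP fusion now happens in LlamaFuser.fuse_mlp above instead of this recursive tree walk. The recursion body is truncated here, but the general walk-and-replace pattern such helpers follow looks roughly like the sketch below (illustrative only; the actual code in this file may differ):

    import torch.nn as nn

    def replace_matching_children(module: nn.Module, matches, build_fused):
        # Generic sketch of a recursive walk-and-replace pass: descend the
        # module tree and swap any child that `matches` for its fused form.
        # Not the actual (truncated) body of make_fused_mlp.
        for name, child in module.named_children():
            if matches(child):
                setattr(module, name, build_fused(child))
            else:
                replace_matching_children(child, matches, build_fused)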