"vscode:/vscode.git/clone" did not exist on "ea0be26b88778b1033d4a176be68bcdd008ff934"
Commit ded3ea71 authored by Casper Hansen
Browse files

Refactor Llama Quant MLP

parent 620966e8
from .base import BaseAWQForCausalLM
from awq.modules import make_fused_mlp
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
class LlamaAWQForCausalLM(BaseAWQForCausalLM):
......@@ -11,7 +10,7 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
fuser = LlamaFuser(awq_model)
fuser.fuse_attention()
fuser.fuse_rmsnorm()
make_fused_mlp(awq_model)#fuser.fuse_mlp()
fuser.fuse_mlp()
@staticmethod
def get_model_layers(model: LlamaForCausalLM):
......@@ -70,9 +69,10 @@ import torch
from typing import List, Tuple
from awq.quantize.qmodule import WQLinear
from awq.utils.utils import set_module_name
from awq.modules.fused_mlp import QuantLlamaMLP
from awq.modules.fused_norm import FTLlamaRMSNorm
from awq.modules.fused_attn import QuantLlamaAttention
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
class LlamaFuser:
def __init__(self, awq_model: BaseAWQForCausalLM):
......@@ -88,6 +88,11 @@ class LlamaFuser:
(name, module) for name, module in self.model.named_modules()
if isinstance(module, LlamaRMSNorm)
]
self.mlp_modules: List[Tuple[str, LlamaMLP]] = [
(name, module) for name, module in self.model.named_modules()
if isinstance(module, LlamaMLP)
]
def fuse_attention(self):
for name, module in self.attention_modules:
......@@ -131,4 +136,6 @@ class LlamaFuser:
set_module_name(self.model, name, norm)
# Swap every LlamaMLP recorded in self.mlp_modules for a QuantLlamaMLP.
def fuse_mlp(self):
# NOTE(review): no-op statement; looks like the pre-refactor placeholder body that
# this diff view shows next to its replacement — confirm before treating it as live code.
pass
# self.mlp_modules holds (name, module) pairs taken from self.model.named_modules(),
# filtered to LlamaMLP instances.
for name, module in self.mlp_modules:
# QuantLlamaMLP receives the projections in the order gate, down, up.
mlp = QuantLlamaMLP(module.gate_proj, module.down_proj, module.up_proj)
# presumably installs `mlp` on self.model under the path `name` — verify against set_module_name.
set_module_name(self.model, name, mlp)
\ No newline at end of file
......@@ -75,9 +75,7 @@ def make_fused_mlp(m, parent_name=''):
"""
Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
"""
if isinstance(m, LlamaMLP):
return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj)
elif "mptmlp" in str(m.__class__).lower():
if "mptmlp" in str(m.__class__).lower():
return QuantMPTMLP(m.up_proj, m.act, m.down_proj)
for name, child in m.named_children():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment