Commit 620966e8 authored by Casper Hansen's avatar Casper Hansen
Browse files

Refactor Llama Quant RMSNorm

parent 2082197d
from .base import BaseAWQForCausalLM
from awq.modules import make_quant_norm, make_fused_mlp
from awq.modules import make_fused_mlp
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
class LlamaAWQForCausalLM(BaseAWQForCausalLM):
......@@ -10,7 +10,7 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
def fuse_layers(awq_model: BaseAWQForCausalLM):
fuser = LlamaFuser(awq_model)
fuser.fuse_attention()
make_quant_norm(awq_model)#fuser.fuse_rmsnorm()
fuser.fuse_rmsnorm()
make_fused_mlp(awq_model)#fuser.fuse_mlp()
@staticmethod
......@@ -70,6 +70,7 @@ import torch
from typing import List, Tuple
from awq.quantize.qmodule import WQLinear
from awq.utils.utils import set_module_name
from awq.modules.fused_norm import FTLlamaRMSNorm
from awq.modules.fused_attn import QuantLlamaAttention
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm
......@@ -125,7 +126,9 @@ class LlamaFuser:
return qkv_layer
def fuse_rmsnorm(self):
pass
for name, module in self.rmsnorm_modules:
norm = FTLlamaRMSNorm(module.weight, module.variance_epsilon)
set_module_name(self.model, name, norm)
def fuse_mlp(self):
pass
import torch
from torch import nn
from transformers.models.llama.modeling_llama import LlamaRMSNorm
import awq_inference_engine
class FTLlamaRMSNorm(nn.Module):
......@@ -16,26 +15,3 @@ class FTLlamaRMSNorm(nn.Module):
output = torch.empty_like(x)
awq_inference_engine.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon)
return output
def make_quant_norm(model):
"""
Replace all LlamaRMSNorm modules with FTLlamaRMSNorm modules
"""
for name, m in model.named_modules():
if not isinstance(m, LlamaRMSNorm):
continue
norm = FTLlamaRMSNorm(m.weight, m.variance_epsilon)
if '.' in name:
parent_name = name.rsplit('.', 1)[0]
child_name = name[len(parent_name) + 1:]
parent = model.get_submodule(parent_name)
else:
parent_name = ''
parent = model
child_name = name
setattr(parent, child_name, norm)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment