Commit 620966e8 authored by Casper Hansen

Refactor Llama Quant RMSNorm

parent 2082197d
 from .base import BaseAWQForCausalLM
-from awq.modules import make_quant_norm, make_fused_mlp
+from awq.modules import make_fused_mlp
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
 class LlamaAWQForCausalLM(BaseAWQForCausalLM):
@@ -10,7 +10,7 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
     def fuse_layers(awq_model: BaseAWQForCausalLM):
         fuser = LlamaFuser(awq_model)
         fuser.fuse_attention()
-        make_quant_norm(awq_model)#fuser.fuse_rmsnorm()
+        fuser.fuse_rmsnorm()
         make_fused_mlp(awq_model)#fuser.fuse_mlp()

     @staticmethod
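Context, not part of the diff: the entry point stays the same, only the RMSNorm step changes owner. A rough usage sketch, assuming the library's usual loading API (the class name, path, and keyword below are assumptions, not lines from this commit):

# Hypothetical sketch (not part of this commit): fuse_layers is invoked on an
# already-quantized Llama wrapper; how that wrapper is obtained is assumed here.
from awq import AutoAWQForCausalLM   # assumed top-level loader class

model = AutoAWQForCausalLM.from_quantized("path/to/quantized-llama",
                                          fuse_layers=True)
# With fuse_layers=True the loader is expected to call
# LlamaAWQForCausalLM.fuse_layers(model), i.e. the method shown above, which
# now swaps LlamaRMSNorm modules via fuser.fuse_rmsnorm() rather than
# make_quant_norm().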
@@ -70,6 +70,7 @@ import torch
 from typing import List, Tuple
 from awq.quantize.qmodule import WQLinear
 from awq.utils.utils import set_module_name
+from awq.modules.fused_norm import FTLlamaRMSNorm
 from awq.modules.fused_attn import QuantLlamaAttention
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm
@@ -125,7 +126,9 @@ class LlamaFuser:
         return qkv_layer

     def fuse_rmsnorm(self):
-        pass
+        for name, module in self.rmsnorm_modules:
+            norm = FTLlamaRMSNorm(module.weight, module.variance_epsilon)
+            set_module_name(self.model, name, norm)

     def fuse_mlp(self):
         pass
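fuse_rmsnorm leans on two pieces this diff does not show: the fuser's rmsnorm_modules list and the set_module_name helper imported above. A minimal sketch of what they presumably do, mirroring the module walk and parent lookup that the removed make_quant_norm (below) performed inline; the exact bodies here are assumptions, not lines from this commit:

import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaRMSNorm

def collect_rmsnorm_modules(model: nn.Module):
    # Presumably what populates LlamaFuser.rmsnorm_modules: every
    # (qualified name, module) pair whose module is a LlamaRMSNorm.
    return [(name, m) for name, m in model.named_modules()
            if isinstance(m, LlamaRMSNorm)]

def set_module_name(model: nn.Module, name: str, new_module: nn.Module):
    # Equivalent of the parent/child lookup the removed make_quant_norm did
    # by hand: resolve the parent module and rebind the child attribute.
    if '.' in name:
        parent_name, child_name = name.rsplit('.', 1)
        parent = model.get_submodule(parent_name)
    else:
        parent, child_name = model, name
    setattr(parent, child_name, new_module)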
 import torch
 from torch import nn
-from transformers.models.llama.modeling_llama import LlamaRMSNorm
 import awq_inference_engine

 class FTLlamaRMSNorm(nn.Module):
@@ -16,26 +15,3 @@ class FTLlamaRMSNorm(nn.Module):
         output = torch.empty_like(x)
         awq_inference_engine.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon)
         return output
-
-def make_quant_norm(model):
-    """
-    Replace all LlamaRMSNorm modules with FTLlamaRMSNorm modules
-    """
-    for name, m in model.named_modules():
-        if not isinstance(m, LlamaRMSNorm):
-            continue
-
-        norm = FTLlamaRMSNorm(m.weight, m.variance_epsilon)
-
-        if '.' in name:
-            parent_name = name.rsplit('.', 1)[0]
-            child_name = name[len(parent_name) + 1:]
-            parent = model.get_submodule(parent_name)
-        else:
-            parent_name = ''
-            parent = model
-            child_name = name
-
-        setattr(parent, child_name, norm)
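For reference, the fused kernel behind FTLlamaRMSNorm.forward stands in for the standard Llama RMSNorm. A plain PyTorch sketch of that computation (the textbook definition, not this repo's CUDA kernel):

import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize each hidden vector by its root-mean-square, then scale by the
    # learned per-channel weight -- the same math layernorm_forward_cuda fuses.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    x_normed = x.float() * torch.rsqrt(variance + eps)
    return weight * x_normed.to(x.dtype)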