Commit f3695d60 authored by Casper Hansen

Fuse MPT

parent 5bd6fbc7
from .base import BaseAWQForCausalLM
from transformers.models.mpt.modeling_mpt import MptBlock, MptForCausalLM, MptMLP
from transformers.models.mpt.modeling_mpt import MptBlock, MptForCausalLM, MptMLP, MptAttention, LayerNorm
class MptAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "MPTBlock"
@@ -8,7 +8,8 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
@staticmethod
    def fuse_layers(model: MptForCausalLM, quant_config: dict):
fuser = MptFuser(model)
fuser.fuse_mlp()
fuser.fuse_attention()
fuser.fuse_layernorm()
@staticmethod
def get_model_layers(model: MptForCausalLM):
@@ -65,26 +66,56 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
return layers
import torch
import xformers
from typing import List, Tuple
from awq.utils.utils import set_module_name
from awq.modules.fused.mlp import QuantMPTMLP
from xformers.triton.layer_norm import FusedLayerNorm
from awq.modules.fused.attn import QuantAttentionFused
class MptFuser:
def __init__(self, model):
self.model = model
        # Collect every module that has a fused replacement, keyed by its
        # fully-qualified name so set_module_name() can swap it in place.
        self.attention_modules: List[Tuple[str, MptAttention]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, MptAttention)
        ]
self.layernorm_modules: List[Tuple[str, LayerNorm]] = [
(name, module) for name, module in self.model.named_modules()
if isinstance(module, LayerNorm)
]
self.mlp_modules: List[Tuple[str, MptMLP]] = [
(name, module) for name, module in self.model.named_modules()
if isinstance(module, MptMLP)
]
def fuse_attention(self):
pass
        # Swap every MptAttention for the fused attention kernel; MPT uses ALiBi
        # position biases rather than rotary embeddings, hence use_alibi=True.
        for name, qkv_layer in self.attention_modules:
            attn = QuantAttentionFused(
                qkv_layer.hidden_size,
                qkv_layer.n_heads,
                qkv_layer,
                qkv_layer.out_proj,
                next(iter(qkv_layer.state_dict().values())).device,  # device of the first weight tensor
                self.model.config.max_new_tokens,
                use_alibi=True
            )
            set_module_name(self.model, name, attn)
def fuse_layernorm(self):
pass
        # Enable fp16 in xformers' Triton layernorm kernel before swapping modules in.
        xformers.triton.k_layer_norm._triton_layernorm_fp16_enabled = True
for name, module in self.layernorm_modules:
norm = FusedLayerNorm(module.weight.shape, eps=module.eps).to(module.weight.device)
# copy weights and bias
with torch.no_grad():
norm.weight = module.weight
norm.bias = module.bias
set_module_name(self.model, name, norm)
def fuse_mlp(self):
for name, module in self.mlp_modules:
mlp = QuantMPTMLP(module.up_proj, module.act, module.down_proj)
set_module_name(self.model, name, mlp)
\ No newline at end of file
pass
\ No newline at end of file
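For orientation, here is a minimal sketch of driving the fuser directly. It assumes `model` is an `MptForCausalLM` whose Linear projections were already replaced by `WQLinear` modules when the AWQ checkpoint was loaded; the `model` handle and the standalone call order are assumptions for illustration, not part of this commit.

# Sketch only, assuming an already-quantized MptForCausalLM is bound to `model`.
fuser = MptFuser(model)
fuser.fuse_attention()   # MptAttention -> QuantAttentionFused (use_alibi=True)
fuser.fuse_layernorm()   # LayerNorm    -> xformers FusedLayerNorm (Triton, fp16 enabled)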
import torch
import torch.nn as nn
import awq_inference_engine
import torch.nn.functional as F
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
class QuantMPTMLP(nn.Module):
def __init__(
self,
up_proj,
act,
down_proj
):
super().__init__()
        # Keep direct handles to the packed 4-bit up-projection tensors so the
        # fused kernel can be called without going through up_proj.forward().
        self.register_buffer('up_proj_qweight', up_proj.qweight)
        self.register_buffer('up_proj_scales', up_proj.scales)
        self.register_buffer('up_proj_qzeros', up_proj.qzeros)

        self.up_proj = up_proj
        self.act = act
        self.down_proj = down_proj
        if isinstance(down_proj, WQLinear_GEMV):
            self.linear = awq_inference_engine.gemv_forward_cuda
            self.group_size = down_proj.group_size
        else:
            self.linear = awq_inference_engine.gemm_forward_cuda
            # Note: the GEMM kernel takes this last argument as its split-k
            # iteration count, not a quantization group size.
            self.group_size = 8
    def forward(self, x: torch.Tensor):
        # Flatten all tokens to 2D, run the quantized up projection through the
        # fused CUDA kernel, then apply the activation and the down projection.
        x = x.reshape(-1, x.shape[-1])
        x = self.linear(
            x,
            self.up_proj_qweight,
            self.up_proj_scales,
            self.up_proj_qzeros,
            self.group_size
        )
        return self.down_proj(self.act(x))
class QuantLlamaMLP(nn.Module):
def __init__(
......
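A minimal usage sketch for QuantMPTMLP. The `model` handle, the block index, and the CUDA placement are illustrative assumptions; it requires the awq_inference_engine kernels and an MPT block whose projections are already WQLinear modules.

import torch
from awq.modules.fused.mlp import QuantMPTMLP

# Assumed: `model` is an AWQ-quantized MptForCausalLM on a CUDA device.
ffn = model.transformer.blocks[0].ffn            # MptMLP holding WQLinear up/down projections
fused_ffn = QuantMPTMLP(ffn.up_proj, ffn.act, ffn.down_proj)

x = torch.randn(1, 16, model.config.d_model, dtype=torch.float16, device="cuda")
with torch.no_grad():
    y = fused_ffn(x)                             # output is flattened to (batch * seq, d_model)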