"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "90eac14f720cf66ca1e28f1cc4af32df44806bc7"
Commit 1df0136e authored by Casper Hansen

Refactor MPT Quant MLP

parent ded3ea71
awq/models/mpt.py

 from .base import BaseAWQForCausalLM
-from awq.modules import make_fused_mlp
+from transformers.models.mpt.modeling_mpt import MptBlock, MptForCausalLM, MptMLP

 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"
     max_new_tokens_key = "max_seq_len"

     @staticmethod
-    def fuse_layers(awq_model):
-        make_fused_mlp(awq_model)
+    def fuse_layers(model: MptForCausalLM):
+        fuser = MptFuser(model)
+        fuser.fuse_mlp()

     @staticmethod
-    def get_model_layers(model):
+    def get_model_layers(model: MptForCausalLM):
         return model.transformer.blocks

     @staticmethod
-    def get_act_for_scaling(module):
+    def get_act_for_scaling(module: MptBlock):
         return dict(
             is_scalable=True,
             scale_name="ffn.act",
@@ -23,12 +24,12 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
         )

     @staticmethod
-    def move_embed(model, device):
+    def move_embed(model: MptForCausalLM, device: str):
         model.transformer.wte = model.transformer.wte.to(device)
         model.transformer.emb_drop = model.transformer.emb_drop.to(device)

     @staticmethod
-    def get_layers_for_scaling(module, input_feat, module_kwargs):
+    def get_layers_for_scaling(module: MptBlock, input_feat, module_kwargs):
         layers = []

         # attention input
@@ -62,4 +63,28 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
             inp=input_feat['ffn.down_proj']
         ))

-        return layers
\ No newline at end of file
+        return layers
+
+from typing import List, Tuple
+from awq.utils.utils import set_module_name
+from awq.modules.fused_mlp import QuantMPTMLP
+
+class MptFuser:
+    def __init__(self, model):
+        self.model = model
+
+        self.mlp_modules: List[Tuple[str, MptMLP]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, MptMLP)
+        ]
+
+    def fuse_attention(self):
+        pass
+
+    def fuse_layernorm(self):
+        pass
+
+    def fuse_mlp(self):
+        for name, module in self.mlp_modules:
+            mlp = QuantMPTMLP(module.up_proj, module.act, module.down_proj)
+            set_module_name(self.model, name, mlp)
\ No newline at end of file
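The fuser swaps modules by their dotted names through awq.utils.utils.set_module_name, whose implementation is not part of this diff. As a minimal sketch (an assumption about what such a helper does, not the library's actual code), the replacement amounts to resolving the parent of a dotted module path and rebinding the attribute:

import torch.nn as nn

def set_module_name(model: nn.Module, name: str, new_module: nn.Module) -> None:
    # Sketch only: rebind the named child on its parent so that, e.g.,
    # "transformer.blocks.0.ffn" ends up pointing at the fused QuantMPTMLP.
    if "." in name:
        parent_name, child_name = name.rsplit(".", 1)
        parent = model.get_submodule(parent_name)  # available in torch >= 1.9
    else:
        parent, child_name = model, name
    setattr(parent, child_name, new_module)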
awq/modules/fused_mlp.py

-import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-
-from torch.cuda.amp import custom_bwd, custom_fwd
-from transformers.models.llama.modeling_llama import LlamaMLP
-
 import awq_inference_engine
+import torch.nn.functional as F

 class QuantMPTMLP(nn.Module):
     def __init__(
@@ -67,23 +63,3 @@ class QuantLlamaMLP(nn.Module):
         c = gate_output * up_output
         c = c.reshape(out_shape)
         return c
-
-def make_fused_mlp(m, parent_name=''):
-    if not hasattr(make_fused_mlp, "called"):
-        make_fused_mlp.called = True
-    """
-    Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
-    """
-
-    if "mptmlp" in str(m.__class__).lower():
-        return QuantMPTMLP(m.up_proj, m.act, m.down_proj)
-
-    for name, child in m.named_children():
-        child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")
-
-        if isinstance(child, QuantLlamaMLP):
-            setattr(m, name, child)
-        elif isinstance(child, QuantMPTMLP):
-            setattr(m, name, child)
-    return m
\ No newline at end of file
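For context, the refactor replaces the model-wide recursive make_fused_mlp traversal removed above with an explicit per-model fuser. A hedged usage sketch of the new entry point, assuming model is an already quantized MptForCausalLM (loading and quantization elided):

# fuse_layers builds an MptFuser and fuses every MptMLP in place:
MptAWQForCausalLM.fuse_layers(model)

# ...which is shorthand for:
#   fuser = MptFuser(model)
#   fuser.fuse_mlp()   # each MptMLP -> QuantMPTMLP(up_proj, act, down_proj)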