Commit 1df0136e authored by Casper Hansen

Refactor MPT Quant MLP

parent ded3ea71
--- a/awq/models/mpt.py
+++ b/awq/models/mpt.py
 from .base import BaseAWQForCausalLM
-from awq.modules import make_fused_mlp
+from transformers.models.mpt.modeling_mpt import MptBlock, MptForCausalLM, MptMLP
 
 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"
     max_new_tokens_key = "max_seq_len"
 
     @staticmethod
-    def fuse_layers(awq_model):
-        make_fused_mlp(awq_model)
+    def fuse_layers(model: MptForCausalLM):
+        fuser = MptFuser(model)
+        fuser.fuse_mlp()
 
     @staticmethod
-    def get_model_layers(model):
+    def get_model_layers(model: MptForCausalLM):
         return model.transformer.blocks
 
     @staticmethod
-    def get_act_for_scaling(module):
+    def get_act_for_scaling(module: MptBlock):
         return dict(
             is_scalable=True,
             scale_name="ffn.act",
@@ -23,12 +24,12 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
         )
 
     @staticmethod
-    def move_embed(model, device):
+    def move_embed(model: MptForCausalLM, device: str):
         model.transformer.wte = model.transformer.wte.to(device)
         model.transformer.emb_drop = model.transformer.emb_drop.to(device)
 
     @staticmethod
-    def get_layers_for_scaling(module, input_feat, module_kwargs):
+    def get_layers_for_scaling(module: MptBlock, input_feat, module_kwargs):
         layers = []
 
         # attention input
@@ -62,4 +63,28 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
             inp=input_feat['ffn.down_proj']
         ))
 
-        return layers
\ No newline at end of file
+        return layers
+
+from typing import List, Tuple
+from awq.utils.utils import set_module_name
+from awq.modules.fused_mlp import QuantMPTMLP
+
+class MptFuser:
+    def __init__(self, model):
+        self.model = model
+
+        self.mlp_modules: List[Tuple[str, MptMLP]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, MptMLP)
+        ]
+
+    def fuse_attention(self):
+        pass
+
+    def fuse_layernorm(self):
+        pass
+
+    def fuse_mlp(self):
+        for name, module in self.mlp_modules:
+            mlp = QuantMPTMLP(module.up_proj, module.act, module.down_proj)
+            set_module_name(self.model, name, mlp)
\ No newline at end of file
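
The new fusing path is explicit: MptFuser collects every MptMLP submodule together with its dotted name and swaps it for a QuantMPTMLP via set_module_name. That helper comes from awq.utils.utils and is not part of this commit; the sketch below is only an illustration, under the assumption that it resolves the parent module from the dotted name and reassigns the child attribute.

import torch.nn as nn

def set_module_name(model: nn.Module, name: str, new_module: nn.Module) -> None:
    # Illustrative sketch only -- the real helper lives in awq.utils.utils and
    # is not shown in this diff. Given a dotted name from model.named_modules()
    # (e.g. "transformer.blocks.0.ffn"), find the parent module and replace the
    # named child with the fused replacement.
    if "." in name:
        parent_name, child_name = name.rsplit(".", 1)
        parent = model.get_submodule(parent_name)
    else:
        parent, child_name = model, name
    setattr(parent, child_name, new_module)

Replacing the module through its parent's attribute (rather than mutating a local reference) is what lets fuse_mlp modify the model in place while iterating over the pre-collected list of (name, module) pairs.
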
--- a/awq/modules/fused_mlp.py
+++ b/awq/modules/fused_mlp.py
-import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-from torch.cuda.amp import custom_bwd, custom_fwd
-from transformers.models.llama.modeling_llama import LlamaMLP
 import awq_inference_engine
+import torch.nn.functional as F
 
 class QuantMPTMLP(nn.Module):
     def __init__(
@@ -67,23 +63,3 @@ class QuantLlamaMLP(nn.Module):
         c = gate_output * up_output
         c = c.reshape(out_shape)
         return c
-
-
-def make_fused_mlp(m, parent_name=''):
-    if not hasattr(make_fused_mlp, "called"):
-        make_fused_mlp.called = True
-
-    """
-    Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
-    """
-    if "mptmlp" in str(m.__class__).lower():
-        return QuantMPTMLP(m.up_proj, m.act, m.down_proj)
-
-    for name, child in m.named_children():
-        child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")
-
-        if isinstance(child, QuantLlamaMLP):
-            setattr(m, name, child)
-        elif isinstance(child, QuantMPTMLP):
-            setattr(m, name, child)
-    return m
\ No newline at end of file
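
For context, QuantMPTMLP is built from the three components of the Hugging Face MptMLP it replaces (up_proj, act, down_proj). The reference module below is a plain, unquantized sketch of that feed-forward path, not the fused W4A16 kernels from awq_inference_engine; the hidden sizes are illustrative, and ancillary details of the real MptMLP (dropout, residual handling) are left out.

import torch
import torch.nn as nn

class ReferenceMptMLP(nn.Module):
    # Unquantized reference for the computation the fused module reproduces:
    # down_proj(act(up_proj(x))). Sizes here are illustrative, not MPT's real config.
    def __init__(self, hidden_size: int = 1024, expansion: int = 4):
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, expansion * hidden_size, bias=False)
        self.act = nn.GELU(approximate="none")
        self.down_proj = nn.Linear(expansion * hidden_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.up_proj(hidden_states)))

# A QuantMPTMLP would be handed up_proj, act and down_proj directly,
# exactly as MptFuser.fuse_mlp does with the real MptMLP instances.
y = ReferenceMptMLP()(torch.randn(2, 8, 1024))  # -> shape (2, 8, 1024)
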