Commit 950851b3 authored by Casper Hansen

Fuse MPT block

parent d7badefc
 from .base import BaseAWQForCausalLM
-from transformers.models.mpt.modeling_mpt import MptBlock, MptForCausalLM, MptMLP, MptAttention, LayerNorm
+from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM, MptAttention

 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"
@@ -9,14 +9,14 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
     def fuse_layers(model: MptForCausalLM, quant_config:dict):
         fuser = MptFuser(model)
         fuser.fuse_attention()
-        fuser.fuse_layernorm()
+        fuser.fuse_block()

     @staticmethod
     def get_model_layers(model: MptForCausalLM):
         return model.transformer.blocks

     @staticmethod
-    def get_act_for_scaling(module: MptBlock):
+    def get_act_for_scaling(module: OldMptBlock):
         return dict(
             is_scalable=True,
             scale_name="ffn.act",
@@ -30,7 +30,7 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
         model.transformer.emb_drop = model.transformer.emb_drop.to(device)

     @staticmethod
-    def get_layers_for_scaling(module: MptBlock, input_feat, module_kwargs):
+    def get_layers_for_scaling(module: OldMptBlock, input_feat, module_kwargs):
         layers = []

         # attention input
@@ -66,11 +66,9 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
         return layers

-import torch
-import xformers
 from typing import List, Tuple
 from awq.utils.utils import set_module_name
-from xformers.triton.layer_norm import FusedLayerNorm
+from awq.modules.fused.block import MptBlock
 from awq.modules.fused.attn import QuantAttentionFused

 class MptFuser:
@@ -82,14 +80,9 @@ class MptFuser:
             if isinstance(module, MptAttention)
         ]

-        self.layernorm_modules: List[Tuple[str, LayerNorm]] = [
+        self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [
             (name, module) for name, module in self.model.named_modules()
-            if isinstance(module, LayerNorm)
-        ]
-
-        self.mlp_modules: List[Tuple[str, MptMLP]] = [
-            (name, module) for name, module in self.model.named_modules()
-            if isinstance(module, MptMLP)
+            if 'mptblock' in module.__class__.__name__.lower()
         ]

     def fuse_attention(self):
@@ -105,17 +98,14 @@ class MptFuser:
             )
             set_module_name(self.model, name, attn)

-    def fuse_layernorm(self):
-        xformers.triton.k_layer_norm._triton_layernorm_fp16_enabled = True
-        for name, module in self.layernorm_modules:
-            norm = FusedLayerNorm(module.weight.shape, eps=module.eps).to(module.weight.device)
-
-            # copy weights and bias
-            with torch.no_grad():
-                norm.weight = module.weight
-                norm.bias = module.bias
-
-            set_module_name(self.model, name, norm)
-
-    def fuse_mlp(self):
-        pass
\ No newline at end of file
+    def fuse_block(self):
+        for name, module in self.mpt_blocks:
+            block = MptBlock(
+                self.model.config.d_model,
+                self.model.config.n_heads,
+                module.attn.Wqkv,
+                module.attn.out_proj,
+                module.ffn
+            )
+
+            set_module_name(self.model, name, block)
\ No newline at end of file
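Both fuse_attention and fuse_block rely on set_module_name from awq.utils.utils to splice the fused module into the model in place of the original one. That helper is not part of this commit; the sketch below is only an assumption about what such a name-based replacement typically looks like (the function name and body are hypothetical, not the AutoAWQ implementation):

import torch.nn as nn

def set_module_name_sketch(model: nn.Module, name: str, new_module: nn.Module) -> None:
    # Hypothetical stand-in for awq.utils.utils.set_module_name (not from this commit).
    # Walk to the parent of the dotted module path, then rebind the attribute so the
    # fused module takes the old module's place in the model tree.
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, new_module)

The fused MptBlock that fuse_block instantiates for each transformer layer is the new module added below.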
import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused

class MptBlock(nn.Module):
    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp):
        super().__init__()
        self.n_heads = n_heads
        self.hidden_size = hidden_size
        self.attn = QuantAttentionFused(hidden_size, self.n_heads, qkv_layer, o_proj, dev="cuda:0", max_seq_len=8096, use_alibi=True).to("cuda:0")
        self.ffn = mpt_mlp.to("cuda:0")
        self.norm_1 = nn.LayerNorm(hidden_size, eps=1e-6).half().to("cuda:0")
        self.norm_2 = nn.LayerNorm(hidden_size, eps=1e-6).half().to("cuda:0")

    def forward(
        self, hidden_states, past_key_value, attn_bias, attention_mask, is_causal
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True
        )

        h = hidden_states + attn_output
        out = h + self.ffn.forward(self.norm_2(h))
        return out, None, past_key_value
\ No newline at end of file
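For context, a minimal usage sketch of the fused path, assuming a quantized MptForCausalLM and its AWQ quant_config are already loaded on GPU; the import path, model id, and prompt below are placeholders rather than anything defined in this commit:

import torch
from transformers import AutoTokenizer
from awq.models.mpt import MptAWQForCausalLM   # assumed import path; the diff does not show the file location

# model: a quantized transformers MptForCausalLM on "cuda:0" (assumed to exist already)
# quant_config: the AWQ quantization config dict used at quantization time (assumed)
MptAWQForCausalLM.fuse_layers(model, quant_config)   # swaps attention modules and MPT blocks for fused versions

tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b")   # placeholder model id
inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda:0")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))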