Commit dd41a223 authored by Casper Hansen

Update MPTBlock, fuse with MPTModel

parent 7631add1
awq/models/mpt.py
@@ -8,7 +8,7 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
     @staticmethod
     def fuse_layers(model: MptForCausalLM, quant_config:dict):
         fuser = MptFuser(model)
-        fuser.fuse_block()
+        fuser.fuse_transformer()
 
     @staticmethod
     def get_model_layers(model: MptForCausalLM):
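Note: fuse_layers is the static hook through which layer fusion is triggered for MPT models; how it gets invoked by the surrounding loader is not shown in this diff. The sketch below is illustrative only, assuming model is an already-quantized MptForCausalLM and quant_config its quantization config (both assumed given, not part of the commit):

# Illustrative sketch, not part of the commit.
from awq.models.mpt import MptAWQForCausalLM

# model / quant_config are assumed to come from the usual quantized-model loading path
MptAWQForCausalLM.fuse_layers(model, quant_config)  # now delegates to MptFuser(model).fuse_transformer()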
@@ -67,10 +67,11 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
 from typing import List, Tuple
 from awq.utils.utils import set_module_name
-from awq.modules.fused.block import MptBlock
+from awq.modules.fused.block import MPTBlock
+from awq.modules.fused.model import MPTModel
 
 
 class MptFuser:
-    def __init__(self, model):
+    def __init__(self, model: MptForCausalLM):
         self.model = model
 
         self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [
@@ -78,15 +79,26 @@ class MptFuser:
             if 'mptblock' in module.__class__.__name__.lower()
         ]
 
-    def fuse_block(self):
-        for name, module in self.mpt_blocks:
-            block = MptBlock(
+    def fuse_transformer(self):
+        blocks = []
+
+        module: OldMptBlock
+        for module in self.model.transformer.blocks:
+            blocks.append(MPTBlock(
                 self.model.config.d_model,
                 self.model.config.n_heads,
                 module.attn.Wqkv,
                 module.attn.out_proj,
                 module.ffn,
-                next(iter(module.state_dict().values())).device
-            )
-
-            set_module_name(self.model, name, block)
\ No newline at end of file
+                module.norm_1,
+                module.norm_2,
+                next(iter(module.state_dict().values())).device,
+                self.model.config.max_new_tokens
+            ))
+
+        self.model.transformer = MPTModel(
+            self.model.config.vocab_size,
+            blocks,
+            self.model.transformer.wte,
+            self.model.transformer.norm_f,
+        )
\ No newline at end of file
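The net effect of the MptFuser change, sketched below under the assumption that model is an already-quantized MptForCausalLM (the final check is illustrative; everything else is taken from the diff above):

# Illustrative sketch, not part of the commit.
fuser = MptFuser(model)     # model: quantized MptForCausalLM (assumed given)
fuser.fuse_transformer()
# model.transformer is now the fused MPTModel built above: it reuses the original
# wte embedding and norm_f, and wraps one MPTBlock per original transformer block,
# each holding the quantized Wqkv/out_proj inside QuantAttentionFused together with
# that layer's original norm_1, norm_2 and ffn modules.
assert type(model.transformer).__name__ == "MPTModel"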
awq/modules/fused/block.py
 import torch.nn as nn
 from awq.modules.fused.attn import QuantAttentionFused
 
-class MptBlock(nn.Module):
-    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, dev):
+class MPTBlock(nn.Module):
+    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len):
         super().__init__()
         self.n_heads = n_heads
         self.hidden_size = hidden_size
-        self.attn = QuantAttentionFused(hidden_size, self.n_heads, qkv_layer, o_proj, dev="cuda:0", max_seq_len=8096, use_alibi=True)
-        self.ffn = mpt_mlp
-        self.norm_1 = nn.LayerNorm(hidden_size, eps=1e-6).half().to(dev)
-        self.norm_2 = nn.LayerNorm(hidden_size, eps=1e-6).half().to(dev)
+        self.norm_1 = norm_1
+        self.attn = QuantAttentionFused(hidden_size, self.n_heads, qkv_layer, o_proj, dev=dev, max_seq_len=max_seq_len, use_alibi=True).to(dev)
+        self.norm_2 = norm_2
+        self.ffn = mpt_mlp.to(dev)
 
     def forward(
-        self, hidden_states, past_key_value, attn_bias, attention_mask, is_causal
+        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
     ):
         norm_out = self.norm_1(hidden_states)
         attn_output, _, past_key_value = self.attn.forward(
......
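For reference, a per-layer sketch of what fuse_transformer does with the new MPTBlock signature. old_block and cfg are hypothetical stand-ins for one original transformer block and model.config, and the dummy forward call at the end assumes passing None as the KV cache is acceptable, which this diff does not show:

# Illustrative sketch, not part of the commit; old_block / cfg are hypothetical names.
import torch
from awq.modules.fused.block import MPTBlock

device = next(iter(old_block.state_dict().values())).device
fused = MPTBlock(
    cfg.d_model,             # hidden_size
    cfg.n_heads,
    old_block.attn.Wqkv,     # fused QKV projection (quantized)
    old_block.attn.out_proj,
    old_block.ffn,           # original MLP is reused as-is
    old_block.norm_1,        # norms are now passed in instead of re-created
    old_block.norm_2,
    device,
    cfg.max_new_tokens,      # becomes max_seq_len of QuantAttentionFused
)

# attn_bias, attention_mask and is_causal now default to None in forward()
hidden = torch.randn(1, 8, cfg.d_model, dtype=torch.float16, device=device)
out = fused(hidden, None)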