"sgl-kernel/vscode:/vscode.git/clone" did not exist on "dd949ace23d6d112d9e3e1e2020deb3278cc8b3c"
mpt.py 2.53 KB
Newer Older
Casper's avatar
Casper committed
1
from .base import BaseAWQForCausalLM
from transformers.models.mpt.modeling_mpt import MptBlock, MptForCausalLM, MptMLP

class MptAWQForCausalLM(BaseAWQForCausalLM):
    # Name of the MPT transformer block class.
    layer_type = "MPTBlock"
    # MPT config key that stores the model's maximum sequence length.
    max_new_tokens_key = "max_seq_len"

    @staticmethod
    def fuse_layers(model: MptForCausalLM):
        # Fuse the quantized MLP sub-layers for faster inference.
        fuser = MptFuser(model)
        fuser.fuse_mlp()

    @staticmethod
    def get_model_layers(model: MptForCausalLM):
        return model.transformer.blocks
    
    @staticmethod
    def get_act_for_scaling(module: MptBlock):
        # Mark the MLP activation (ffn.act) as scalable and report its output width.
        return dict(
            is_scalable=True,
            scale_name="ffn.act",
            scale_layer=module.ffn.act,
            scale_shape=module.ffn.up_proj.out_features
        )
    
    @staticmethod
    def move_embed(model: MptForCausalLM, device: str):
        # Move the token embedding and embedding dropout to the target device.
        model.transformer.wte = model.transformer.wte.to(device)
        model.transformer.emb_drop = model.transformer.emb_drop.to(device)
    
    @staticmethod
    def get_layers_for_scaling(module: MptBlock, input_feat, module_kwargs):
        # Each entry describes one scaling group: `prev_op` is the preceding op whose
        # weights absorb the inverse scales, `layers` are the linears whose inputs get
        # scaled, `inp` is the cached input activation for those linears, and
        # `module2inspect`/`kwargs` let the scale search re-run the enclosing module.
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.norm_1,
            layers=[module.attn.Wqkv],
            inp=input_feat['attn.Wqkv'],
            module2inspect=module.attn,
            kwargs=module_kwargs
        ))

        # attention output
        layers.append(dict(
            prev_op=module.attn.Wqkv,
            layers=[module.attn.out_proj],
            inp=input_feat['attn.out_proj']
        ))

        # linear 1
        layers.append(dict(
            prev_op=module.norm_2,
            layers=[module.ffn.up_proj],
            inp=input_feat['ffn.up_proj'],
            module2inspect=module.ffn
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.ffn.act,
            layers=[module.ffn.down_proj],
            inp=input_feat['ffn.down_proj']
        ))

        return layers

from typing import List, Tuple
from awq.utils.utils import set_module_name
from awq.modules.fused_mlp import QuantMPTMLP

class MptFuser:
    def __init__(self, model):
        self.model = model

        # Collect all (name, MptMLP) pairs so fuse_mlp can replace them in place.
        self.mlp_modules: List[Tuple[str, MptMLP]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, MptMLP)
        ]
    
    def fuse_attention(self):
        # Attention fusion is not implemented for MPT.
        pass

    def fuse_layernorm(self):
        # LayerNorm fusion is not implemented for MPT.
        pass

    def fuse_mlp(self):
        # Swap each MptMLP for its fused quantized counterpart.
        for name, module in self.mlp_modules:
            mlp = QuantMPTMLP(module.up_proj, module.act, module.down_proj)
            set_module_name(self.model, name, mlp)
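

# ---------------------------------------------------------------------------
# Illustrative usage sketch only, not exercised by the library itself. It
# assumes AutoAWQ's top-level `AutoAWQForCausalLM` API (from_pretrained /
# quantize / save_quantized) and an example MPT checkpoint name; both are
# assumptions for illustration, not defined in this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    model_path = "mosaicml/mpt-7b"  # assumed example checkpoint
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

    # Load the FP16 model and tokenizer, run AWQ quantization, then save.
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model.quantize(tokenizer, quant_config=quant_config)
    model.save_quantized("mpt-7b-awq")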