from .base import BaseAWQForCausalLM
from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM
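
# AWQ support for MPT (MosaicML) models: scaling hooks consumed during
# quantization, plus an optional fused-module path applied afterwards.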


class MptAWQForCausalLM(BaseAWQForCausalLM):
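    """AWQ adapter for MPT: tells BaseAWQForCausalLM where the transformer
    blocks live, which activation to scale, and how to fuse layers."""
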
    layer_type = "MPTBlock"
    max_seq_len_key = "max_seq_len"

    @staticmethod
    def fuse_layers(model: MptForCausalLM):
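        # Swap the eager transformer for fused modules (see MptFuser below).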
        fuser = MptFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: MptForCausalLM):
        return model.transformer.blocks

    @staticmethod
    def get_act_for_scaling(module: OldMptBlock):
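        # The FFN activation between up_proj and down_proj is scalable; report
        # its module path and the width of up_proj's output.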
        return dict(
            is_scalable=True,
            scale_name="ffn.act",
            scale_layer=module.ffn.act,
            scale_shape=module.ffn.up_proj.out_features,
        )

    @staticmethod
    def move_embed(model: MptForCausalLM, device: str):
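        # Keep the token embeddings (and embedding dropout) on the same device
        # as the rest of the calibration inputs.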
        model.transformer.wte = model.transformer.wte.to(device)
        model.transformer.emb_drop = model.transformer.emb_drop.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldMptBlock, input_feat, module_kwargs):
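        # Each entry groups linear layers behind a shared preceding op so AWQ
        # can search for one input scale per group.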
        layers = []

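        # The scale search calls module2inspect's forward, which does not
        # accept output_attentions, so drop it from the kwargs.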
        if module_kwargs.get("output_attentions") is not None:
            module_kwargs.pop("output_attentions")

        # attention input: Wqkv, scaled against the output of norm_1
        layers.append(
            dict(
                prev_op=module.norm_1,
                layers=[module.attn.Wqkv],
                inp=input_feat["attn.Wqkv"],
                module2inspect=module.attn,
                kwargs=module_kwargs,
            )
        )

        # attention output: out_proj, scaled against the Wqkv projection
        layers.append(
            dict(
                prev_op=module.attn.Wqkv,
                layers=[module.attn.out_proj],
                inp=input_feat["attn.out_proj"],
            )
        )

        # linear 1: ffn.up_proj, scaled against the output of norm_2
        layers.append(
            dict(
                prev_op=module.norm_2,
                layers=[module.ffn.up_proj],
                inp=input_feat["ffn.up_proj"],
                module2inspect=module.ffn,
            )
        )

        # linear 2: ffn.down_proj, scaled against the FFN activation
        layers.append(
            dict(
                prev_op=module.ffn.act,
                layers=[module.ffn.down_proj],
                inp=input_feat["ffn.down_proj"],
            )
        )

        return layers


from typing import List, Tuple
from awq.utils.utils import set_module_name
from awq.modules.fused.block import MPTBlock
from awq.modules.fused.model import MPTModel


class MptFuser:
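    """Rebuilds the model's eager MptBlocks as fused MPTBlock kernels and
    wraps them in a fused MPTModel."""
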
    def __init__(self, model: MptForCausalLM):
        self.model = model

        self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "mptblock" in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
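        # Rebuild every block as a fused MPTBlock, reusing the (quantized)
        # projections, FFN, and norms of the original module.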
        blocks = []

        module: OldMptBlock
        for module in self.model.transformer.blocks:
            blocks.append(
                MPTBlock(
                    self.model.config.d_model,
                    self.model.config.n_heads,
                    module.attn.Wqkv,
                    module.attn.out_proj,
                    module.ffn,
                    module.norm_1,
                    module.norm_2,
                    next(iter(module.state_dict().values())).device,
                    self.model.config.max_seq_len,
                )
            )

        self.model.transformer = MPTModel(
            self.model.config.vocab_size,
            blocks,
            self.model.transformer.wte,
            self.model.transformer.norm_f,
        )

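        # Effectively a no-op: MPTModel already exposes .blocks, so this just
        # restates the attribute on the fused transformer.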
        setattr(self.model.transformer, "blocks", self.model.transformer.blocks)
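
# Minimal usage sketch, assuming AutoAWQ's AutoAWQForCausalLM entry point and
# an AWQ-quantized MPT checkpoint at a hypothetical local path:
#
#   from awq import AutoAWQForCausalLM
#
#   model = AutoAWQForCausalLM.from_quantized("mpt-7b-awq", fuse_layers=True)
#   # from_quantized() ends up invoking MptAWQForCausalLM.fuse_layers(), which
#   # installs the fused MPTBlock/MPTModel modules defined in this file.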