import tqdm
import torch
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.modules.fused.block import MixtralBlock
from awq.modules.fused.model import MixtralModel
from awq.modules.fused.moe import FusedSparseMoeBlock
from awq.utils.fused_utils import fuse_qkv, fuse_linears
from transformers.models.mixtral.modeling_mixtral import (
    MixtralDecoderLayer as OldMixtralDecoderLayer,
    MixtralForCausalLM as OldMixtralForCausalLM,
)
from awq.modules.linear import WQLinear_GEMM
from awq.modules.fused.norm import FasterTransformerRMSNorm


class MixtralAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MixtralDecoderLayer"
    max_seq_len_key = "max_position_embeddings"
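    # The MoE router ("gate") is excluded from quantization and kept in full precision.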
    modules_to_not_convert = ["gate"]

    @staticmethod
    def fuse_layers(model: OldMixtralForCausalLM):
        fuser = MixtralFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldMixtralForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model: OldMixtralForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(
        module: OldMixtralDecoderLayer, input_feat, module_kwargs
    ):
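        # Each entry below groups linear layers that consume the same input with
        # the op that produces it, plus the calibration features recorded for that
        # input; AWQ searches one shared scale per group.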
        layers = []

        # attention input
        layers.append(
            dict(
                prev_op=module.input_layernorm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )

        # attention out
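        # o_proj is only scaled against v_proj when their weight shapes match;
        # with Mixtral's grouped-query attention v_proj is smaller, so this is skipped.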
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        # linear in
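        # Every expert's w1/w3 sees the same routed hidden states, so all experts
        # are scaled together as one group against post_attention_layernorm.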
        layers.append(
            dict(
                prev_op=module.post_attention_layernorm,
                layers=[
                    w
                    for expert in module.block_sparse_moe.experts
                    for w in [expert.w1, expert.w3]
                ],
                inp=input_feat["block_sparse_moe"],
                module2inspect=module.block_sparse_moe,
            )
        )

        # linear out
        for i, expert in enumerate(module.block_sparse_moe.experts):
            layers.append(
                dict(
                    prev_op=expert.w3,
                    layers=[expert.w2],
                    inp=input_feat[f"block_sparse_moe.experts.{i}.w2"],
                )
            )

        return layers


class MixtralFuser:
    def __init__(self, model: OldMixtralForCausalLM):
        self.model = model

        self.mixtral_blocks: List[Tuple[str, OldMixtralDecoderLayer]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "MixtralDecoderLayer".lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
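        # Rebuild each HF decoder layer as a fused MixtralBlock (fused QKV,
        # FasterTransformer RMSNorms, optionally a fused sparse MoE), then swap
        # self.model.model for the fused MixtralModel at the end.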
        blocks = []

        module: OldMixtralDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight, module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon,
            )

            sparse_moe = module.block_sparse_moe
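            # Fuse each expert's w1/w3 into one projection and stack all experts
            # into a single FusedSparseMoeBlock. This is only done for GEMM-packed
            # weights on a single GPU; otherwise the HF MoE block is kept as-is.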
            if (
                isinstance(sparse_moe.experts[0].w1, WQLinear_GEMM)
                and torch.cuda.device_count() == 1
            ):
                fused_w1w3s = [
                    fuse_linears(
                        [
                            sparse_moe.experts[i].w1,
                            sparse_moe.experts[i].w3,
                        ],
                        device,
                    )
                    for i in range(len(sparse_moe.experts))
                ]

                stacked_w1w3s = fuse_linears(
                    fused_w1w3s, device, dim=0, operation=torch.stack
                )

                stacked_w2s = fuse_linears(
                    [expert.w2 for expert in sparse_moe.experts],
                    device,
                    dim=0,
                    operation=torch.stack,
                )

                sparse_moe = FusedSparseMoeBlock(
                    top_k=sparse_moe.top_k,
                    gate=sparse_moe.gate,
                    ws=stacked_w1w3s,
                    w2s=stacked_w2s,
                )

            blocks.append(
                MixtralBlock(
                    hidden_size=self.model.config.hidden_size,
                    n_heads=self.model.config.num_attention_heads,
                    n_kv_heads=self.model.config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    moe=sparse_moe,
                    norm_1=norm_1,
                    norm_2=norm_2,
                    dev=device,
                    max_seq_len=self.model.config.max_seq_len,
                    rope_theta=self.model.config.rope_theta,
                )
            )
        
        model_norm = FasterTransformerRMSNorm(
            self.model.model.norm.weight,
            self.model.model.norm.variance_epsilon,
        )

        self.model.model = MixtralModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            model_norm,
        )
        setattr(self.model.model, "blocks", self.model.model.blocks)
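
# Usage sketch: this adapter is normally reached through AutoAWQ's model factory
# rather than instantiated directly; passing `fuse_layers=True` routes through
# MixtralFuser above. The checkpoint path below is a placeholder.
#
#   from awq import AutoAWQForCausalLM
#
#   model = AutoAWQForCausalLM.from_quantized(
#       "path/to/mixtral-awq",  # placeholder: any AWQ-quantized Mixtral checkpoint
#       fuse_layers=True,
#   )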