mistral.py 6 KB
Newer Older
Casper Hansen's avatar
Casper Hansen committed
1
import logging
Casper Hansen's avatar
Casper Hansen committed
2
from typing import Dict
Casper Hansen's avatar
Casper Hansen committed
3
from .base import BaseAWQForCausalLM
Casper Hansen's avatar
Casper Hansen committed
4
5
6
7
8

try:
    from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralForCausalLM
except ImportError:
    # Only a failed import should trigger the fallback; a bare `except:` would also
    # swallow KeyboardInterrupt/SystemExit and genuine errors raised inside transformers.
    # Per the warning below, the Mistral model needs transformers 4.34.0.dev0 or newer.
    # The None placeholders are used by MistralAWQForCausalLM only in type annotations.
    # NOTE(review): the MistralAttention/MistralRMSNorm/MistralMLP import further down
    # this file is not guarded, so the module may still fail to import without a
    # Mistral-capable transformers — confirm whether that import should be guarded too.
    # TODO: Remove once released on PyPi
    logging.warning("You need the latest transformers 4.34.0.dev0: pip install -U git+https://github.com/huggingface/transformers.git")
    MistralForCausalLM = None
    MistralDecoderLayer = None
Casper Hansen's avatar
Casper Hansen committed
12
13
14
15
16

class MistralAWQForCausalLM(BaseAWQForCausalLM):
    """Mistral-specific hooks for ``BaseAWQForCausalLM``.

    Locates the decoder layers and token embeddings inside a Mistral causal-LM,
    lists the groups of linear layers handed back for scaling, and swaps in
    fused modules via ``MistralFuser``.
    """

    # Class name of the decoder layer. NOTE(review): presumably read by the base class.
    layer_type = "MistralDecoderLayer"
    # Config key for the context length. NOTE(review): presumably read by the base class.
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: MistralForCausalLM, quant_config: Dict):
        """Fuse ``model``'s attention, RMSNorm and MLP modules in place, in that order."""
        fuser = MistralFuser(model, quant_config)
        for fuse_step in (fuser.fuse_attention, fuser.fuse_rmsnorm, fuser.fuse_mlp):
            fuse_step()

    @staticmethod
    def get_model_layers(model: MistralForCausalLM):
        """Return the decoder layer container stored at ``model.model.layers``."""
        backbone = model.model
        return backbone.layers

    @staticmethod
    def get_act_for_scaling(module: MistralDecoderLayer):
        """Report that no activation of a Mistral decoder layer is scalable."""
        return {"is_scalable": False}

    @staticmethod
    def move_embed(model: MistralForCausalLM, device: str):
        """Move the token embedding module to ``device``, reassigning it on ``model.model``."""
        backbone = model.model
        backbone.embed_tokens = backbone.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module: MistralDecoderLayer, input_feat, module_kwargs):
        """Return the groups of linear layers to scale inside one decoder layer.

        Args:
            module: the decoder layer being processed.
            input_feat: mapping from sub-module name (e.g. ``'self_attn.q_proj'``) to
                the recorded input for that sub-module.
                NOTE(review): populated by the caller — confirm its exact contents.
            module_kwargs: passed through as ``kwargs`` for the attention group only.

        Returns:
            A list of dicts, each with ``prev_op``, ``layers`` and ``inp``. The
            attention-input group also carries ``module2inspect`` and ``kwargs``,
            and the MLP-input group carries ``module2inspect``.
        """
        attn = module.self_attn
        mlp = module.mlp

        # Group 1: q/k/v projections, preceded by input_layernorm.
        groups = [
            dict(
                prev_op=module.input_layernorm,
                layers=[attn.q_proj, attn.k_proj, attn.v_proj],
                inp=input_feat['self_attn.q_proj'],
                module2inspect=attn,
                kwargs=module_kwargs,
            )
        ]

        # Group 2 (optional): o_proj, preceded by v_proj. Only added when the two
        # weight matrices have identical shapes; otherwise (e.g. when v_proj is
        # narrower, as with grouped-query attention) this group is skipped.
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if attn.v_proj.weight.shape == attn.o_proj.weight.shape:
            groups.append(dict(
                prev_op=attn.v_proj,
                layers=[attn.o_proj],
                inp=input_feat['self_attn.o_proj'],
            ))

        # Group 3: gate/up projections, preceded by post_attention_layernorm.
        groups.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[mlp.gate_proj, mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=mlp,
        ))

        # Group 4: down projection, preceded by up_proj.
        groups.append(dict(
            prev_op=mlp.up_proj,
            layers=[mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return groups
Casper Hansen's avatar
Casper Hansen committed
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160

import torch
from typing import List, Tuple, Union
from awq.utils.utils import set_module_name
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.attn import QuantAttentionFused
from awq.modules.fused.norm import FasterTransformerRMSNorm
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from transformers.models.mistral.modeling_mistral import MistralAttention, MistralRMSNorm, MistralMLP

class MistralFuser:
    """Swap sub-modules of a quantized Mistral model for fused versions, in place.

    Target modules are collected once, at construction, by scanning
    ``model.named_modules()`` for ``MistralAttention``, ``MistralRMSNorm`` and
    ``MistralMLP`` instances. Each ``fuse_*`` method replaces its collected
    modules by name via ``set_module_name``.
    """

    def __init__(self, model, quant_config):
        self.model = model
        # Stored for callers; not read by any method of this class.
        self.quant_config = quant_config

        def collect(module_cls):
            # (qualified name, module) pairs for every sub-module of type module_cls.
            return [
                (qualified_name, sub_module)
                for qualified_name, sub_module in self.model.named_modules()
                if isinstance(sub_module, module_cls)
            ]

        self.attention_modules: List[Tuple[str, MistralAttention]] = collect(MistralAttention)
        self.rmsnorm_modules: List[Tuple[str, MistralRMSNorm]] = collect(MistralRMSNorm)
        self.mlp_modules: List[Tuple[str, MistralMLP]] = collect(MistralMLP)

    def fuse_attention(self):
        """Replace each collected attention module with a ``QuantAttentionFused``.

        q/k/v are merged into one quantized linear by ``_fuse_qkv``; ``o_proj`` is
        reused as-is. The fused module is placed on the merged layer's device.
        """
        for qualified_name, attn_module in self.attention_modules:
            merged_qkv: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv(attn_module)
            qkv_device = next(iter(merged_qkv.state_dict().values())).device
            fused_attn = QuantAttentionFused(
                attn_module.hidden_size,
                attn_module.num_heads,
                attn_module.num_key_value_heads,
                merged_qkv,
                attn_module.o_proj,
                qkv_device,
                self.model.config.max_new_tokens,
            )
            set_module_name(self.model, qualified_name, fused_attn)

    def _fuse_qkv(self, module: MistralAttention):
        """Build one quantized linear that stacks ``q_proj``, ``k_proj`` and ``v_proj``.

        The new layer uses the same class (GEMV or GEMM) as ``q_proj``, and copies
        its bit width, group size and input width. Its output width is the sum of the
        three projections' widths. Packed weights, zeros and scales are concatenated
        along dim 0 for GEMV and dim 1 for GEMM. The bias is the dim-0 concatenation
        of the three biases, or None when ``q_proj`` has no bias.
        """
        projections = (module.q_proj, module.k_proj, module.v_proj)
        q_proj = projections[0]
        has_bias = q_proj.bias is not None

        merged_bias = None
        if has_bias:
            merged_bias = torch.cat([proj.bias for proj in projections], dim=0)

        linear_cls = WQLinear_GEMV if isinstance(q_proj, WQLinear_GEMV) else WQLinear_GEMM
        merged = linear_cls(
            q_proj.w_bit,
            q_proj.group_size,
            q_proj.in_features,
            sum(proj.out_features for proj in projections),
            has_bias,
            next(iter(module.state_dict().values())).device,
        )

        # The two packed layouts stack the projections along different axes
        # (presumably each layout's output-channel axis).
        is_gemv = isinstance(merged, WQLinear_GEMV)
        cat_dim = 0 if is_gemv else 1
        merged.qweight = torch.cat([proj.qweight for proj in projections], dim=cat_dim)
        merged.qzeros = torch.cat([proj.qzeros for proj in projections], dim=cat_dim)
        merged.scales = torch.cat([proj.scales for proj in projections], dim=cat_dim)
        if is_gemv:
            merged.split_k_iters = q_proj.split_k_iters

        merged.bias = merged_bias
        return merged

    def fuse_rmsnorm(self):
        """Replace each collected RMSNorm with ``FasterTransformerRMSNorm`` (same weight and epsilon)."""
        for qualified_name, norm_module in self.rmsnorm_modules:
            replacement = FasterTransformerRMSNorm(norm_module.weight, norm_module.variance_epsilon)
            set_module_name(self.model, qualified_name, replacement)

    def fuse_mlp(self):
        """Replace each collected MLP with ``QuantLlamaMLP`` built from its gate/down/up projections."""
        for qualified_name, mlp_module in self.mlp_modules:
            replacement = QuantLlamaMLP(
                mlp_module.gate_proj,
                mlp_module.down_proj,
                mlp_module.up_proj,
            )
            set_module_name(self.model, qualified_name, replacement)