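"""AWQ support for Mistral models.

Defines the per-layer scaling groups used during AWQ quantization and the
fused-module replacements (fused QKV projection, QuantLlamaMLP, and
FasterTransformer RMSNorm) used for accelerated inference.

Illustrative usage via AutoAWQ's top-level API (a minimal sketch; the
checkpoint path is a placeholder)::

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    quant_path = "path/to/mistral-awq"  # placeholder AWQ checkpoint
    model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
    tokenizer = AutoTokenizer.from_pretrained(quant_path)

    tokens = tokenizer("Hello, my name is", return_tensors="pt").input_ids.cuda()
    print(tokenizer.decode(model.generate(tokens, max_new_tokens=32)[0]))
"""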
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.mistral.modeling_mistral import (
    MistralDecoderLayer as OldMistralDecoderLayer,
    MistralForCausalLM as OldMistralForCausalLM
)
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class MistralAWQForCausalLM(BaseAWQForCausalLM):
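    """AWQ wrapper for Mistral models.

    `layer_type` names the decoder block that is quantized, and
    `max_new_tokens_key` points at the config field consulted for the
    maximum sequence length of the fused modules.
    """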
    layer_type = "MistralDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldMistralForCausalLM):
        fuser = MistralFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldMistralForCausalLM):
        return model.model.layers
    
    @staticmethod
    def get_act_for_scaling(module: OldMistralDecoderLayer):
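        # No activation-function scaling is applied for Mistral.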
        return dict(
            is_scalable=False
        )
    
    @staticmethod
    def move_embed(model: OldMistralForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)
    
    @staticmethod
    def get_layers_for_scaling(module: OldMistralDecoderLayer, input_feat, module_kwargs):
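        # Each entry groups linear layers that share one AWQ scale: the op that
        # feeds them (prev_op), the captured input activations (inp), and
        # optionally the parent module to re-run while searching for scales.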
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.q_proj,
                    module.self_attn.k_proj, module.self_attn.v_proj],
            inp=input_feat['self_attn.q_proj'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(dict(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.o_proj],
                inp=input_feat['self_attn.o_proj'],
            ))
        
        # linear 1
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.gate_proj, module.mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=module.mlp,
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.up_proj,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return layers


class MistralFuser:
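    """Swaps the original MistralDecoderLayer stack for fused LlamaLikeBlocks."""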
    def __init__(self, model: OldMistralForCausalLM):
        self.model = model

        self.mistral_blocks: List[Tuple[str, OldMistralDecoderLayer]] = [
            (name, module) for name, module in self.model.named_modules()
            if 'MistralDecoderLayer'.lower() in module.__class__.__name__.lower()
        ]
    
    def fuse_transformer(self):
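        # For each decoder layer, fuse q/k/v into a single projection and wrap
        # the MLP and RMSNorms in fused kernels, then rebuild the model as a
        # LlamaLikeModel over the resulting blocks.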
        blocks = []

        module: OldMistralDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
            mlp = QuantLlamaMLP(
                module.mlp.gate_proj,
                module.mlp.down_proj,
                module.mlp.up_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon
            )
            blocks.append(LlamaLikeBlock(
                hidden_size=self.model.config.hidden_size,
                n_heads=self.model.config.num_attention_heads,
                n_kv_heads=self.model.config.num_key_value_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens
            ))
        
        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )