import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.mistral.modeling_mistral import (
    MistralDecoderLayer as OldMistralDecoderLayer,
    MistralForCausalLM as OldMistralForCausalLM
)
from awq.modules.fused.norm import FasterTransformerRMSNorm

class MistralAWQForCausalLM(BaseAWQForCausalLM):
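    """AWQ integration for Hugging Face Mistral models: exposes the hooks the
    quantizer needs (layer access, scaling groups, embedding placement) and
    optional fused-layer inference."""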
    layer_type = "MistralDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldMistralForCausalLM):
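        """Replace the model's decoder layers with fused modules via MistralFuser."""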
        fuser = MistralFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldMistralForCausalLM):
        return model.model.layers
    
    @staticmethod
    def get_act_for_scaling(module: OldMistralDecoderLayer):
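        # Mistral layers have no standalone activation module to scale, so report is_scalable=False.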
        return dict(
            is_scalable=False
        )
    
    @staticmethod
    def move_embed(model: OldMistralForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)
    
    @staticmethod
    def get_layers_for_scaling(module: OldMistralDecoderLayer, input_feat, module_kwargs):
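        """Group the decoder layer's linear modules for AWQ scaling, pairing each group
        with the op that precedes it and the captured input activations."""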
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj],
            inp=input_feat['self_attn.q_proj'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
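        # With grouped-query attention, v_proj and o_proj have different shapes, so this scaling is skipped.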
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(dict(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.o_proj],
                inp=input_feat['self_attn.o_proj'],
            ))
        
        # linear 1
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.gate_proj, module.mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=module.mlp,
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.up_proj,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return layers


class MistralFuser:
    def __init__(self, model: OldMistralForCausalLM):
        self.model = model

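        # Collect (name, module) pairs for every MistralDecoderLayer in the model.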
        self.mistral_blocks: List[Tuple[str, OldMistralDecoderLayer]] = [
            (name, module) for name, module in self.model.named_modules()
            if 'MistralDecoderLayer'.lower() in module.__class__.__name__.lower()
        ]
    
    def fuse_transformer(self):
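        """Fuse each decoder layer's QKV projections and RMSNorms into a LlamaLikeBlock,
        then replace the Transformers model with a fused LlamaLikeModel."""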
        blocks = []

        module: OldMistralDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
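            # Infer the layer's device from its state dict.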
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon
            )
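            # Combine the fused QKV, the original o_proj and MLP, and the fused norms into a Llama-style block.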
            blocks.append(LlamaLikeBlock(
                hidden_size=self.model.config.hidden_size,
                n_heads=self.model.config.num_attention_heads,
                n_kv_heads=self.model.config.num_key_value_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=module.mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens
            ))
        
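        # Swap in the fused LlamaLikeModel, reusing the original embeddings and final norm.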
        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
        setattr(self.model.model, "blocks", self.model.model.blocks)