llama.py
from .base import BaseAWQForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

class LlamaAWQForCausalLM(BaseAWQForCausalLM):
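    # Class-level hooks read by BaseAWQForCausalLM: `layer_type` names the decoder
    # block class to target, and `max_new_tokens_key` points at the config field
    # used to bound generated tokens.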
    layer_type = "LlamaDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: LlamaForCausalLM, quant_config: dict):
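        # Swap attention, RMSNorm and MLP submodules for fused implementations
        # (see LlamaFuser below).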
        fuser = LlamaFuser(model, quant_config)
        fuser.fuse_attention()
        fuser.fuse_rmsnorm()
        fuser.fuse_mlp()

    @staticmethod
    def get_model_layers(model: LlamaForCausalLM):
        return model.model.layers
    
    @staticmethod
    def get_act_for_scaling(module: LlamaDecoderLayer):
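        # Llama has no standalone activation module that gets its own AWQ scale
        # (the gate/up path is covered by get_layers_for_scaling), so nothing is
        # reported as scalable here.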
        return dict(
            is_scalable=False
        )
    
    @staticmethod
    def move_embed(model: LlamaForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)
    
    @staticmethod
    def get_layers_for_scaling(module: LlamaDecoderLayer, input_feat, module_kwargs):
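        # Each dict below describes one AWQ scaling group: `prev_op` is the module whose
        # output feeds the group (and absorbs the inverse scales), `layers` are the linears
        # scaled jointly, `inp` is their cached calibration input, and `module2inspect`,
        # when given, is the parent module re-run during the scale search.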
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.q_proj,
                    module.self_attn.k_proj, module.self_attn.v_proj],
            inp=input_feat['self_attn.q_proj'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(dict(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.o_proj],
                inp=input_feat['self_attn.o_proj'],
            ))
        
        # linear 1
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.gate_proj, module.mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=module.mlp,
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.up_proj,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return layers

import torch
from typing import List, Tuple, Union
from awq.utils.utils import set_module_name
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.norm import FTLlamaRMSNorm
from awq.modules.fused.attn import QuantLlamaAttention, QuantLlamaAttentionFused
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP

class LlamaFuser:
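    # Collects the LlamaAttention / LlamaRMSNorm / LlamaMLP submodules of an
    # AWQ-quantized model and replaces them in-place with fused equivalents.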
    def __init__(self, model, quant_config):
        self.model = model
        self.quant_config = quant_config

        self.attention_modules: List[Tuple[str, LlamaAttention]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LlamaAttention)
        ]

        self.rmsnorm_modules: List[Tuple[str, LlamaRMSNorm]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LlamaRMSNorm)
        ]
        
        self.mlp_modules: List[Tuple[str, LlamaMLP]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LlamaMLP)
        ]
    
    def fuse_attention(self):
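        # Merge the q/k/v projections into a single quantized linear and replace each
        # LlamaAttention with the fused attention implementation.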
        for name, module in self.attention_modules:
            qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv2(module)
            # attn = QuantLlamaAttention(
            #     module.hidden_size,
            #     module.num_heads,
            #     module.num_key_value_heads,
            #     qkv_layer,
            #     module.o_proj,
            #     next(iter(qkv_layer.state_dict().values())).device,
            #     self.model.config.max_new_tokens
            # )
            attn = QuantLlamaAttentionFused(
                module.hidden_size,
                module.num_heads,
                qkv_layer, 
                module.o_proj,
                next(iter(qkv_layer.state_dict().values())).device,
                self.model.config.max_new_tokens
            )
            set_module_name(self.model, name, attn)
    
    def _fuse_qkv2(self, module: LlamaAttention):
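        # Concatenate the quantized q/k/v tensors (weights, zero-points, scales) and pack
        # them into a single WQLinear_GEMV layer so QKV is computed in one call.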
        q_proj = module.q_proj
        k_proj = module.k_proj
        v_proj = module.v_proj

        qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
        qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
        scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
        # g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
        g_idx = None
        bias = (
            torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0)
            if q_proj.bias is not None
            else None
        )

        qkv_layer = WQLinear_GEMV(
            q_proj.w_bit,
            q_proj.group_size,
            q_proj.in_features,
            q_proj.out_features + k_proj.out_features + v_proj.out_features,
            q_proj.bias is not None,
            q_proj.qweight.device,
        )
        qkv_layer.qweight = qweights
        qkv_layer.qzeros = qzeros
        qkv_layer.scales = scales

        qkv_layer.bias = bias
        qkv_layer.split_k_iters = q_proj.split_k_iters

        return qkv_layer
    
    def _fuse_qkv(self, module: LlamaAttention):
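        # Alternative fusion path that picks WQLinear_GEMM or WQLinear_GEMV based on
        # quant_config["version"]; fuse_attention above currently uses _fuse_qkv2 instead.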
        # get qkv and bias
        q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None

        # create module
        if self.quant_config["version"] == 'GEMM':
            qkv_module = WQLinear_GEMM
        elif self.quant_config["version"] == 'GEMV':
            qkv_module = WQLinear_GEMV
        else:
            raise ValueError(f"Unknown quantization version: {self.quant_config['version']}")

        qkv_layer = qkv_module(
            q_proj.w_bit, 
            q_proj.group_size, 
            q_proj.in_features, 
            q_proj.out_features + k_proj.out_features + v_proj.out_features, 
            q_proj.bias is not None,
            next(iter(module.state_dict().values())).device
        )

        # replace buffers with real weights
        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
        qkv_layer.bias = bias
        qkv_layer.split_k_iters = q_proj.split_k_iters

        return qkv_layer

    def fuse_rmsnorm(self):
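        # Replace each LlamaRMSNorm with the fused FTLlamaRMSNorm kernel, reusing the
        # original weight and variance epsilon.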
        for name, module in self.rmsnorm_modules:
            norm = FTLlamaRMSNorm(module.weight, module.variance_epsilon)
            set_module_name(self.model, name, norm)

    def fuse_mlp(self):
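        # Replace each LlamaMLP with the quantized fused MLP built from its
        # gate/down/up projections.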
        for name, module in self.mlp_modules:
            mlp = QuantLlamaMLP(module.gate_proj, module.down_proj, module.up_proj)
            set_module_name(self.model, name, mlp)
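
# Usage sketch (assumptions: `model` is an already-quantized LlamaForCausalLM whose
# Linear layers were replaced by WQLinear_GEMM/WQLinear_GEMV, `quant_config` is the
# matching quantization config dict, and `model.config.max_new_tokens` is set, since
# fuse_attention reads it for the fused cache size):
#
#   LlamaAWQForCausalLM.fuse_layers(model, quant_config)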