from .base import BaseAWQForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

class LlamaAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "LlamaDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: LlamaForCausalLM, quant_config: dict):
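        # Replace attention, RMSNorm and MLP modules with fused implementations for faster inference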
        fuser = LlamaFuser(model, quant_config)
        fuser.fuse_attention()
        fuser.fuse_rmsnorm()
        fuser.fuse_mlp()

    @staticmethod
    def get_model_layers(model: LlamaForCausalLM):
        return model.model.layers
    
    @staticmethod
    def get_act_for_scaling(module: LlamaDecoderLayer):
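        # Llama decoder layers expose no standalone activation for AWQ to scale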
        return dict(
            is_scalable=False
        )
    
    @staticmethod
    def move_embed(model: LlamaForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)
    
    @staticmethod
    def get_layers_for_scaling(module: LlamaDecoderLayer, input_feat, module_kwargs):
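        # Describe which groups of linear layers AWQ should scale together, and which op precedes each group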
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.q_proj,
                    module.self_attn.k_proj, module.self_attn.v_proj],
            inp=input_feat['self_attn.q_proj'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(dict(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.o_proj],
                inp=input_feat['self_attn.o_proj'],
            ))
        
        # linear 1
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.gate_proj, module.mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=module.mlp,
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.up_proj,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return layers

import torch
from typing import List, Tuple, Union
from awq.utils.utils import set_module_name
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.norm import FTLlamaRMSNorm
from awq.modules.fused.attn import QuantAttentionFused
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP

class LlamaFuser:
    def __init__(self, model, quant_config):
        self.model = model
        self.quant_config = quant_config

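        # Collect the attention, RMSNorm and MLP modules that will be replaced by fused versions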
        self.attention_modules: List[Tuple[str, LlamaAttention]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LlamaAttention)
        ]

        self.rmsnorm_modules: List[Tuple[str, LlamaRMSNorm]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LlamaRMSNorm)
        ]
        
        self.mlp_modules: List[Tuple[str, LlamaMLP]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LlamaMLP)
        ]
    
    def fuse_attention(self):
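        # Fuse q/k/v into a single quantized projection and swap in the fused attention module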
        for name, module in self.attention_modules:
            qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv(module)
            attn = QuantAttentionFused(
                module.hidden_size,
                module.num_heads,
                module.num_key_value_heads,
                qkv_layer, 
                module.o_proj,
                next(iter(qkv_layer.state_dict().values())).device,
                self.model.config.max_new_tokens
            )
            set_module_name(self.model, name, attn)
    
    def _fuse_qkv(self, module: LlamaAttention):
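        # Concatenate the separate q/k/v projections into one quantized linear layer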
        q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None

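        # Keep the same quantized kernel type (GEMV vs GEMM) as the original projections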
        if isinstance(q_proj, WQLinear_GEMV):
            q_linear = WQLinear_GEMV
        else:
            q_linear = WQLinear_GEMM

        qkv_layer = q_linear(
            q_proj.w_bit,
            q_proj.group_size,
            q_proj.in_features,
            q_proj.out_features + k_proj.out_features + v_proj.out_features,
            q_proj.bias is not None,
            next(iter(module.state_dict().values())).device
        )

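        # WQLinear_GEMV lays out the output dimension along dim 0, WQLinear_GEMM along dim 1,
        # so the q/k/v tensors are concatenated on the matching dimension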
        if isinstance(qkv_layer, WQLinear_GEMV):
            qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
            qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
            qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
            qkv_layer.split_k_iters = q_proj.split_k_iters
        else:
            qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
            qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
            qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
        
        qkv_layer.bias = bias

        return qkv_layer

    def fuse_rmsnorm(self):
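        # Swap each LlamaRMSNorm for the fused FTLlamaRMSNorm kernel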
        for name, module in self.rmsnorm_modules:
            norm = FTLlamaRMSNorm(module.weight, module.variance_epsilon)
            set_module_name(self.model, name, norm)

    def fuse_mlp(self):
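        # Swap each LlamaMLP for QuantLlamaMLP built from its gate/up/down projections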
        for name, module in self.mlp_modules:
            mlp = QuantLlamaMLP(module.gate_proj, module.down_proj, module.up_proj)
            set_module_name(self.model, name, mlp)