from .base import BaseAWQForCausalLM
from typing import Dict
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
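
# AWQ integration for HF Llama models. LlamaAWQForCausalLM tells the AWQ base
# class how to walk the decoder layers, which groups of linear layers to scale
# during quantization, and how to swap in fused modules afterwards (LlamaFuser).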

class LlamaAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "LlamaDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: LlamaForCausalLM, quant_config: Dict):
        fuser = LlamaFuser(model, quant_config)
        fuser.fuse_attention()
        fuser.fuse_rmsnorm()
        fuser.fuse_mlp()

    @staticmethod
    def get_model_layers(model: LlamaForCausalLM):
        return model.model.layers
    
    @staticmethod
    def get_act_for_scaling(module: LlamaDecoderLayer):
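        # Llama's SwiGLU MLP has no standalone activation module for AWQ to scale,
        # so activation scaling is reported as disabled for this architecture.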
        return dict(
            is_scalable=False
        )
    
    @staticmethod
    def move_embed(model: LlamaForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)
    
    @staticmethod
    def get_layers_for_scaling(module: LlamaDecoderLayer, input_feat, module_kwargs):
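        # Each dict below describes one AWQ scaling group: prev_op is the op whose
        # output feeds the listed linear layers, inp is a captured sample of their
        # input activations, and module2inspect (when present) is the parent module
        # re-run when evaluating candidate scales.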
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.q_proj,
                    module.self_attn.k_proj, module.self_attn.v_proj],
            inp=input_feat['self_attn.q_proj'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
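        # (o_proj scaling is skipped for grouped-query models, where the v_proj and
        # o_proj weight shapes no longer match)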
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(dict(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.o_proj],
                inp=input_feat['self_attn.o_proj'],
            ))
        
        # linear 1
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.gate_proj, module.mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=module.mlp,
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.up_proj,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return layers
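
# Everything below implements layer fusion: after a quantized checkpoint is loaded,
# LlamaFuser swaps HF's LlamaAttention, LlamaRMSNorm and LlamaMLP modules for fused
# equivalents (packed-QKV attention, FT RMSNorm kernel, fused quantized MLP).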

import torch
from typing import List, Tuple, Union
from awq.utils.utils import set_module_name
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.norm import FTLlamaRMSNorm
from awq.modules.fused.attn import QuantAttentionFused
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP

class LlamaFuser:
    def __init__(self, model, quant_config):
        self.model = model
        self.quant_config = quant_config

        self.attention_modules: List[Tuple[str, LlamaAttention]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LlamaAttention)
        ]

        self.rmsnorm_modules: List[Tuple[str, LlamaRMSNorm]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LlamaRMSNorm)
        ]
        
        self.mlp_modules: List[Tuple[str, LlamaMLP]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LlamaMLP)
        ]
    
    def fuse_attention(self):
        for name, module in self.attention_modules:
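            # Replace each LlamaAttention with QuantAttentionFused, which takes a
            # single packed QKV projection (built in _fuse_qkv) plus the existing
            # o_proj, its device, and the model's max_new_tokens setting.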
            qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv(module)
            attn = QuantAttentionFused(
                module.hidden_size,
                module.num_heads,
                module.num_key_value_heads,
                qkv_layer,
                module.o_proj,
                next(iter(qkv_layer.state_dict().values())).device,
                self.model.config.max_new_tokens
            )
            set_module_name(self.model, name, attn)
    
    def _fuse_qkv(self, module: LlamaAttention):
        q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None

        if isinstance(q_proj, WQLinear_GEMV):
            q_linear = WQLinear_GEMV
        else:
            q_linear = WQLinear_GEMM

        qkv_layer = q_linear(
            q_proj.w_bit,
            q_proj.group_size,
            q_proj.in_features,
            q_proj.out_features + k_proj.out_features + v_proj.out_features,
            q_proj.bias is not None,
            next(iter(module.state_dict().values())).device
        )

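        # Concatenate the quantized Q/K/V tensors into a single projection. In the
        # GEMV layout the output dimension is dim 0 (and split_k_iters is carried
        # over); in the GEMM layout it is dim 1.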
        if isinstance(qkv_layer, WQLinear_GEMV):
            qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
            qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
            qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
            qkv_layer.split_k_iters = q_proj.split_k_iters
        else:
            qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
            qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
            qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
        
        qkv_layer.bias = bias

        return qkv_layer

    def fuse_rmsnorm(self):
        for name, module in self.rmsnorm_modules:
            norm = FTLlamaRMSNorm(module.weight, module.variance_epsilon)
            set_module_name(self.model, name, norm)

    def fuse_mlp(self):
        for name, module in self.mlp_modules:
            mlp = QuantLlamaMLP(module.gate_proj, module.down_proj, module.up_proj)
            set_module_name(self.model, name, mlp)
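
# Rough usage sketch (assumes AutoAWQ's public loader API; the exact entry point and
# keyword names may differ between versions):
#
#   from awq import AutoAWQForCausalLM
#
#   model = AutoAWQForCausalLM.from_quantized("<path-to-llama-awq-model>", fuse_layers=True)
#   # from_quantized() dispatches to LlamaAWQForCausalLM.fuse_layers(), which runs LlamaFuser.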