from .base import BaseAWQForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

class LlamaAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "LlamaDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"
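    # Class attributes consumed by BaseAWQForCausalLM: layer_type identifies the
    # decoder block to quantize, and max_new_tokens_key names the config field
    # that caps generation length (likely the source of the config.max_new_tokens
    # value the fuser passes to the fused attention below).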

    @staticmethod
    def fuse_layers(model: LlamaForCausalLM, quant_config: dict):
        fuser = LlamaFuser(model, quant_config)
        fuser.fuse_attention()
        fuser.fuse_rmsnorm()
        fuser.fuse_mlp()

    @staticmethod
    def get_model_layers(model: LlamaForCausalLM):
        return model.model.layers
    
    @staticmethod
    def get_act_for_scaling(module: LlamaDecoderLayer):
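        # Llama's decoder block has no standalone activation module whose output
        # AWQ can scale, so the block is reported as not scalable.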
        return dict(
            is_scalable=False
        )
    
    @staticmethod
    def move_embed(model: LlamaForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)
    
    @staticmethod
    def get_layers_for_scaling(module: LlamaDecoderLayer, input_feat, module_kwargs):
        layers = []
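        # Each dict below describes one AWQ scaling group: `layers` are the linear
        # layers whose inputs share a scale, `prev_op` is the preceding op that
        # absorbs the inverse scales, `inp` is the captured calibration input, and
        # `module2inspect` (when set) is the submodule re-run during the scale search.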

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj],
            inp=input_feat['self_attn.q_proj'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
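        # The shape check below skips scaling o_proj through v_proj when their
        # shapes differ (e.g. grouped-query attention), since the scales could not
        # be folded into v_proj in that case.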
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(dict(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.o_proj],
                inp=input_feat['self_attn.o_proj'],
            ))
        
        # linear 1
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.gate_proj, module.mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=module.mlp,
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.up_proj,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return layers

import torch
from typing import List, Tuple, Union
from awq.utils.utils import set_module_name
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.norm import FTLlamaRMSNorm
from awq.modules.fused.attn import QuantLlamaAttentionFused
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP

class LlamaFuser:
    def __init__(self, model, quant_config):
        self.model = model
        self.quant_config = quant_config
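        # Collect the target modules up front: fusing replaces them in the module
        # tree, so we avoid mutating it while iterating named_modules().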

        self.attention_modules: List[Tuple[str, LlamaAttention]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LlamaAttention)
        ]

        self.rmsnorm_modules: List[Tuple[str, LlamaRMSNorm]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LlamaRMSNorm)
        ]
        
        self.mlp_modules: List[Tuple[str, LlamaMLP]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LlamaMLP)
        ]
    
    def fuse_attention(self):
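        # Replace each LlamaAttention with a fused variant that runs q/k/v through
        # a single quantized projection; config.max_new_tokens is passed in,
        # presumably to bound the fused module's cache.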
        for name, module in self.attention_modules:
            qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv(module)
            attn = QuantLlamaAttentionFused(
                module.hidden_size,
                module.num_heads,
                qkv_layer, 
                module.o_proj,
                next(iter(qkv_layer.state_dict().values())).device,
                self.model.config.max_new_tokens
            )
            set_module_name(self.model, name, attn)
    
    def _fuse_qkv(self, module: LlamaAttention):
        q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None

        qkv_layer = WQLinear_GEMV(
            q_proj.w_bit,
            q_proj.group_size,
            q_proj.in_features,
            q_proj.out_features + k_proj.out_features + v_proj.out_features,
            q_proj.bias is not None,
            q_proj.qweight.device,
        )
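        # Concatenate the packed weights, zero points and scales along the output
        # dimension so a single GEMV kernel computes q, k and v together.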
        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)

        qkv_layer.bias = bias
        qkv_layer.split_k_iters = q_proj.split_k_iters

        return qkv_layer

    def fuse_rmsnorm(self):
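        # Swap each LlamaRMSNorm for the fused FTLlamaRMSNorm kernel, reusing the
        # learned weight and variance epsilon.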
        for name, module in self.rmsnorm_modules:
            norm = FTLlamaRMSNorm(module.weight, module.variance_epsilon)
            set_module_name(self.model, name, norm)

    def fuse_mlp(self):
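        # Swap each LlamaMLP for QuantLlamaMLP, which wraps the same gate, up and
        # down projections in a fused forward implementation.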
        for name, module in self.mlp_modules:
            mlp = QuantLlamaMLP(module.gate_proj, module.down_proj, module.up_proj)
            set_module_name(self.model, name, mlp)