from .base import BaseAWQForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

class LlamaAWQForCausalLM(BaseAWQForCausalLM):
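    """AWQ wiring for Llama-family models: tells the generic
    BaseAWQForCausalLM machinery where the decoder layers live, which linear
    projections are scaled together during quantization, and how to fuse the
    quantized modules afterwards (see LlamaFuser below).
    """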
    layer_type = "LlamaDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(awq_model: BaseAWQForCausalLM):
        fuser = LlamaFuser(awq_model)
        fuser.fuse_attention()
        fuser.fuse_rmsnorm()
        fuser.fuse_mlp()

    @staticmethod
    def get_model_layers(model: LlamaForCausalLM):
        return model.model.layers
    
    @staticmethod
    def get_act_for_scaling(module: LlamaDecoderLayer):
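        # Llama has no standalone activation op whose output needs its own AWQ
        # scale, so nothing is reported as scalable here.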
        return dict(
            is_scalable=False
        )
    
    @staticmethod
    def move_embed(model: LlamaForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)
    
    @staticmethod
    def get_layers_for_scaling(module: LlamaDecoderLayer, input_feat, module_kwargs):
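        """Return the groups of linear layers that share one AWQ scale.

        Each dict names the layers to scale, the preceding op (``prev_op``)
        into which the inverse scale is folded, the cached calibration input
        (``inp``), and optionally a parent module (``module2inspect``) to
        re-run when searching for the best scale.
        """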
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.q_proj,
                    module.self_attn.k_proj, module.self_attn.v_proj],
            inp=input_feat['self_attn.q_proj'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # attention out
        # Only scale o_proj against v_proj when their weight shapes match; with
        # grouped-query attention they differ, so this step is skipped.
        # See https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(dict(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.o_proj],
                inp=input_feat['self_attn.o_proj'],
            ))
        
        # linear 1
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.gate_proj, module.mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=module.mlp,
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.up_proj,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return layers

import torch
from typing import List, Tuple
from awq.quantize.qmodule import WQLinear
from awq.utils.utils import set_module_name
from awq.modules.fused_mlp import QuantLlamaMLP
from awq.modules.fused_norm import FTLlamaRMSNorm
from awq.modules.fused_attn import QuantLlamaAttention
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP

class LlamaFuser:
    def __init__(self, awq_model: BaseAWQForCausalLM):
        self.awq_model = awq_model
        self.model = awq_model.model

        self.attention_modules: List[Tuple[str, LlamaAttention]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LlamaAttention)
        ]

        self.rmsnorm_modules: List[Tuple[str, LlamaRMSNorm]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LlamaRMSNorm)
        ]
        
        self.mlp_modules: List[Tuple[str, LlamaMLP]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LlamaMLP)
        ]
    
    def fuse_attention(self):
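        """Replace each LlamaAttention with a QuantLlamaAttention wrapping a
        single packed q/k/v WQLinear, so the three projections run as one
        quantized matmul.
        """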
        for name, module in self.attention_modules:
            qkv_layer: WQLinear = self._fuse_qkv(module)
            attn = QuantLlamaAttention(
                module.hidden_size,
                module.num_heads,
                qkv_layer,
                module.o_proj,
                qkv_layer.qweight.device,
                self.awq_model.model.config.max_new_tokens
            )
            set_module_name(self.model, name, attn)
    
    def _fuse_qkv(self, module: LlamaAttention):
        # get qkv and bias
        q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
        
        # create module
        qkv_layer = WQLinear(
            q_proj.w_bit, 
            q_proj.group_size, 
            q_proj.in_features, 
            q_proj.out_features + k_proj.out_features + v_proj.out_features, 
            q_proj.bias is not None,
            q_proj.qweight.device
        )

        # overwrite the freshly allocated buffers with the real q/k/v tensors,
        # concatenated along the packed output-feature dimension (dim=1)
        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
        qkv_layer.bias = bias

        return qkv_layer

    def fuse_rmsnorm(self):
        for name, module in self.rmsnorm_modules:
            norm = FTLlamaRMSNorm(module.weight, module.variance_epsilon)
            set_module_name(self.model, name, norm)

    def fuse_mlp(self):
        for name, module in self.mlp_modules:
            mlp = QuantLlamaMLP(module.gate_proj, module.down_proj, module.up_proj)
            set_module_name(self.model, name, mlp)
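
# Minimal usage sketch (assumptions, not part of this file: the package-level
# AutoAWQForCausalLM entry point exposes a `fuse_layers` flag as in
# contemporary AutoAWQ releases, and "path/to/llama-awq" is a hypothetical
# quantized checkpoint):
#
#   from awq import AutoAWQForCausalLM
#   model = AutoAWQForCausalLM.from_quantized("path/to/llama-awq", fuse_layers=True)
#
# With fuse_layers enabled, LlamaAWQForCausalLM.fuse_layers runs the LlamaFuser
# above, swapping in QuantLlamaAttention, FTLlamaRMSNorm and QuantLlamaMLP.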