from .base import BaseAWQForCausalLM
from awq.modules import make_fused_mlp
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

class LlamaAWQForCausalLM(BaseAWQForCausalLM):
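    # layer_type names the decoder block class that AWQ quantizes block by block;
    # max_new_tokens_key points at the config attribute (max_position_embeddings)
    # used as the default upper bound on generated tokens.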
    layer_type = "LlamaDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(awq_model: BaseAWQForCausalLM):
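        # Replace eligible submodules with fused implementations after quantization:
        # attention (single fused QKV projection), fused RMSNorm, and the fused MLP.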
        fuser = LlamaFuser(awq_model)
        fuser.fuse_attention()
        fuser.fuse_rmsnorm()
        make_fused_mlp(awq_model)  # MLP fusion is handled here rather than by fuser.fuse_mlp()

    @staticmethod
    def get_model_layers(model: LlamaForCausalLM):
        return model.model.layers
    
    @staticmethod
    def get_act_for_scaling(module: LlamaDecoderLayer):
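        # No activation-output scaling is applied for Llama decoder layers.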
        return dict(
            is_scalable=False
        )
    
    @staticmethod
    def move_embed(model: LlamaForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)
    
    @staticmethod
    def get_layers_for_scaling(module: LlamaDecoderLayer, input_feat, module_kwargs):
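        # Each entry groups linear layers that consume the same input (inp) so they
        # can share one AWQ scale; prev_op is the preceding op the inverse scales are
        # folded into, and module2inspect is the parent module run during scale search.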
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.q_proj,
                    module.self_attn.k_proj, module.self_attn.v_proj],
            inp=input_feat['self_attn.q_proj'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # attention out
        # o_proj is only scaled when v_proj and o_proj weights have the same shape,
        # which skips grouped-query attention models; see
        # https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(dict(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.o_proj],
                inp=input_feat['self_attn.o_proj'],
            ))
        
        # linear 1: gate/up projections, scaled against the post-attention layernorm
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.gate_proj, module.mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=module.mlp,
        ))

        # linear 2: down projection, scaled against the output of up_proj
        layers.append(dict(
            prev_op=module.mlp.up_proj,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return layers

import torch
from typing import List, Tuple
from awq.quantize.qmodule import WQLinear
from awq.utils.utils import set_module_name
from awq.modules.fused_norm import FTLlamaRMSNorm
from awq.modules.fused_attn import QuantLlamaAttention
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm

class LlamaFuser:
    def __init__(self, awq_model: BaseAWQForCausalLM):
        self.awq_model = awq_model
        self.model = awq_model.model

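        # Collect (name, module) pairs up front so fused replacements can be written
        # back by dotted module name via set_module_name.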
        self.attention_modules: List[Tuple[str, LlamaAttention]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LlamaAttention)
        ]

        self.rmsnorm_modules: List[Tuple[str, LlamaRMSNorm]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LlamaRMSNorm)
        ]
    
    def fuse_attention(self):
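        # Replace each LlamaAttention with QuantLlamaAttention backed by a single
        # fused QKV WQLinear.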
        for name, module in self.attention_modules:
            qkv_layer: WQLinear = self._fuse_qkv(module)
            attn = QuantLlamaAttention(
                module.hidden_size,
                module.num_heads,
                qkv_layer,
                module.o_proj,
                qkv_layer.qweight.device,
                self.awq_model.model.config.max_new_tokens
            )
            set_module_name(self.model, name, attn)
    
    def _fuse_qkv(self, module: LlamaAttention):
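        # Build one WQLinear covering q, k and v so attention runs a single
        # quantized GEMM instead of three.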
        # get qkv and bias
        q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
        
        # create module
        qkv_layer = WQLinear(
            q_proj.w_bit, 
            q_proj.group_size, 
            q_proj.in_features, 
            q_proj.out_features + k_proj.out_features + v_proj.out_features, 
            q_proj.bias is not None,
            q_proj.qweight.device
        )

        # replace buffers with real weights
        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
        qkv_layer.bias = bias

        return qkv_layer

    def fuse_rmsnorm(self):
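        # Swap every LlamaRMSNorm for the fused FTLlamaRMSNorm implementation.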
        for name, module in self.rmsnorm_modules:
            norm = FTLlamaRMSNorm(module.weight, module.variance_epsilon)
            set_module_name(self.model, name, norm)

    def fuse_mlp(self):
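        # Intentionally a no-op: MLP fusion is applied via make_fused_mlp() in
        # LlamaAWQForCausalLM.fuse_layers instead.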
        pass