import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer as OldAquilaDecoderLayer,
    LlamaForCausalLM as OldAquilaForCausalLM
)
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class AquilaAWQForCausalLM(BaseAWQForCausalLM):
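    """AWQ integration for Aquila models. Aquila shares the Llama architecture,
    so the Hugging Face Llama classes are aliased above and the fused Llama
    modules are reused at inference time."""
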
    layer_type = "AquilaDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldAquilaForCausalLM):
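        """Replace the model's decoder layers with fused modules for faster inference."""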
        fuser = AquilaFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldAquilaForCausalLM):
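        """Return the decoder layers that AWQ quantizes."""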
        return model.model.layers
    
    @staticmethod
    def get_act_for_scaling(module: OldAquilaDecoderLayer):
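        """No activation-function scaling is applied for this architecture."""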
        return dict(
            is_scalable=False
        )
    
    @staticmethod
    def move_embed(model: OldAquilaForCausalLM, device: str):
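        """Move the token embedding table to the target device."""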
        model.model.embed_tokens = model.model.embed_tokens.to(device)
    
    @staticmethod
    def get_layers_for_scaling(module: OldAquilaDecoderLayer, input_feat, module_kwargs):
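        """Group the linear layers that share an AWQ scale, together with the
        preceding op and the captured input activations for each group."""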
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            ],
            inp=input_feat['self_attn.q_proj'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(dict(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.o_proj],
                inp=input_feat['self_attn.o_proj'],
            ))
        
        # linear 1
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.gate_proj, module.mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=module.mlp,
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.up_proj,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return layers


class AquilaFuser:
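    """Rebuilds each Aquila decoder layer as a fused Llama-like block."""
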
    def __init__(self, model: OldAquilaForCausalLM):
        self.model = model

        self.aquila_blocks: List[Tuple[str, OldAquilaDecoderLayer]] = [
            (name, module) for name, module in self.model.named_modules()
            if 'AquilaDecoderLayer'.lower() in module.__class__.__name__.lower()
        ]
    
    def fuse_transformer(self):
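        """Convert every decoder layer into a LlamaLikeBlock and swap the wrapped
        transformer for a LlamaLikeModel."""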
        blocks = []

        module: OldAquilaDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
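            # Concatenate the quantized q/k/v projections into one fused linear layer.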
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
            mlp = QuantLlamaMLP(
                module.mlp.gate_proj,
                module.mlp.down_proj,
                module.mlp.up_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon
            )
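            # Reassemble the layer as a fused block; max_seq_len bounds the fused KV cache.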
            blocks.append(LlamaLikeBlock(
                hidden_size=self.model.config.hidden_size,
                n_heads=self.model.config.num_attention_heads,
                n_kv_heads=self.model.config.num_key_value_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens
            ))
        
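        # Replace the original transformer with the fused model; lm_head is left untouched.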
        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )