import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
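# Aquila shares the Llama decoder architecture, so the stock Hugging Face Llama
# classes are imported below under Aquila aliases.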
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer as OldAquilaDecoderLayer,
    LlamaForCausalLM as OldAquilaForCausalLM
)
from awq.modules.fused.norm import FasterTransformerRMSNorm

class AquilaAWQForCausalLM(BaseAWQForCausalLM):
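    # Class attributes read by BaseAWQForCausalLM: the name of the decoder-layer
    # class and the config key that holds the maximum sequence length.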
    layer_type = "AquilaDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldAquilaForCausalLM):
        fuser = AquilaFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldAquilaForCausalLM):
        return model.model.layers
    
    @staticmethod
    def get_act_for_scaling(module: OldAquilaDecoderLayer):
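        # Llama-style Aquila blocks have no activation that AWQ rescales,
        # so activation scaling is reported as disabled.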
        return dict(
            is_scalable=False
        )
    
    @staticmethod
    def move_embed(model: OldAquilaForCausalLM, device: str):
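        # Move the token embedding table to the requested device.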
        model.model.embed_tokens = model.model.embed_tokens.to(device)
    
    @staticmethod
    def get_layers_for_scaling(module: OldAquilaDecoderLayer, input_feat, module_kwargs):
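        # Each entry groups linear layers that consume the same input with the
        # op that precedes them and the captured input activations; AWQ uses
        # these groups to search for per-channel scaling factors.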
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.q_proj,
                    module.self_attn.k_proj, module.self_attn.v_proj],
            inp=input_feat['self_attn.q_proj'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(dict(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.o_proj],
                inp=input_feat['self_attn.o_proj'],
            ))
        
        # linear 1
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.gate_proj, module.mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=module.mlp,
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.up_proj,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return layers


class AquilaFuser:
    def __init__(self, model: OldAquilaForCausalLM):
        self.model = model

        self.aquila_blocks: List[Tuple[str, OldAquilaDecoderLayer]] = [
            (name, module) for name, module in self.model.named_modules()
            if 'AquilaDecoderLayer'.lower() in module.__class__.__name__.lower()
        ]
    
    def fuse_transformer(self):
        blocks = []

        module: OldAquilaDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
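            # Merge the separate q/k/v projections into one fused linear layer
            # for the fused attention block.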
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
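            # Re-wrap the RMSNorm weights and epsilon in the FasterTransformer
            # kernel implementation used by the fused block.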
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon
            )
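            # Rebuild the decoder layer as a fused Llama-style block.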
            blocks.append(LlamaLikeBlock(
                hidden_size=self.model.config.hidden_size,
                n_heads=self.model.config.num_attention_heads,
                n_kv_heads=self.model.config.num_key_value_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=module.mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens
            ))
        
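        # Swap the Hugging Face decoder stack for the fused model assembled
        # above, reusing the original embeddings and final norm.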
        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )

        setattr(self.model.model, "blocks", self.model.model.blocks)
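
# Usage sketch (assumptions: an AWQ-quantized Aquila checkpoint is available at
# the placeholder path, and the installed AutoAWQ version exposes the loader
# shown here):
#
#   from awq import AutoAWQForCausalLM
#   model = AutoAWQForCausalLM.from_quantized("path/to/aquila-awq", fuse_layers=True)
#
# Passing fuse_layers=True makes AutoAWQ call AquilaAWQForCausalLM.fuse_layers()
# defined above.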