# yi.py
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from awq.modules.fused.norm import FasterTransformerRMSNorm

class YiAWQForCausalLM(BaseAWQForCausalLM):
    """Yi-specific hooks layered on top of ``BaseAWQForCausalLM``.

    The base class lives outside this file. How it consumes the class
    attributes and static methods below is presumed from their names --
    confirm against the base implementation.
    """

    # Class name of a single decoder layer in the wrapped model.
    layer_type = "YiDecoderLayer"
    # Name of the config attribute holding the maximum sequence length
    # (presumably looked up by the base class -- TODO confirm).
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model):
        """Fuse the decoder layers of ``model`` in place via ``YiFuser``."""
        YiFuser(model).fuse_transformer()

    @staticmethod
    def get_model_layers(model):
        """Return the decoder-layer container, ``model.model.layers``."""
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module):
        """Return the activation-scaling descriptor for ``module``.

        Always reports ``is_scalable=False``; ``module`` is not inspected.
        """
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model, device: str):
        """Move the token embedding of ``model`` to ``device``.

        The result of ``.to(device)`` is assigned back onto
        ``model.model.embed_tokens``.
        """
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module, input_feat, module_kwargs):
        """Describe the groups of linear layers in one decoder layer.

        Args:
            module: A decoder layer exposing ``ln1``, ``ln2``, ``self_attn``
                (with ``q_proj``/``k_proj``/``v_proj``/``o_proj``) and ``mlp``
                (with ``gate_proj``/``up_proj``/``down_proj``).
            input_feat: Mapping indexed by the keys ``"self_attn.q_proj"``,
                ``"self_attn.o_proj"``, ``"mlp.gate_proj"`` and
                ``"mlp.down_proj"``; the values are passed through as ``inp``.
            module_kwargs: Passed through unchanged as ``kwargs`` of the
                attention-input group.

        Returns:
            A list of dicts, in order: attention input, attention output
            (only when the shape check below passes), MLP input, MLP output.
            Every dict has ``prev_op``, ``layers`` and ``inp``; the two input
            groups also carry ``module2inspect``.
        """
        # attention input
        layers = [
            dict(
                prev_op=module.ln1,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        ]

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        # The v_proj -> o_proj group is only emitted when both weights have
        # the same shape; otherwise it is skipped entirely.
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        # linear 1
        layers.append(
            dict(
                prev_op=module.ln2,
                layers=[module.mlp.gate_proj, module.mlp.up_proj],
                inp=input_feat["mlp.gate_proj"],
                module2inspect=module.mlp,
            )
        )

        # linear 2
        layers.append(
            dict(
                prev_op=module.mlp.up_proj,
                layers=[module.mlp.down_proj],
                inp=input_feat["mlp.down_proj"],
            )
        )

        return layers


class YiFuser:
    """Swap a Yi model's decoder stack for fused ``LlamaLikeBlock`` modules.

    Attributes:
        model: The wrapped model passed to the constructor.
        yi_blocks: ``(name, module)`` pairs from ``model.named_modules()``
            whose class name contains ``"YiDecoderLayer"`` (case-insensitive).
    """

    def __init__(self, model):
        """Store ``model`` and collect its Yi decoder layers.

        Args:
            model: Object exposing ``named_modules()``. ``fuse_transformer``
                additionally reads ``model.config`` and ``model.model``.
        """
        self.model = model

        # Lower-case the needle once; the match is a case-insensitive
        # substring test on each module's class name.
        target = "YiDecoderLayer".lower()
        self.yi_blocks: List[Tuple[str, object]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if target in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        """Rebuild ``self.model.model`` as a ``LlamaLikeModel`` of fused blocks.

        For every layer in ``self.model.model.layers`` the q/k/v projections
        are combined with ``fuse_qkv``, both norms are wrapped in
        ``FasterTransformerRMSNorm``, and the pieces are assembled into a
        ``LlamaLikeBlock``. The original embedding and final norm are reused.

        Reads ``hidden_size``, ``num_attention_heads``, ``num_key_value_heads``,
        ``max_seq_len``, ``rope_theta`` and ``vocab_size`` from
        ``self.model.config``. ``max_seq_len`` is presumably populated by the
        surrounding framework before fusing -- TODO confirm.

        Returns:
            None. ``self.model.model`` is replaced in place.
        """
        config = self.model.config
        blocks = []

        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            # Use the device of the layer's first state-dict tensor for the
            # fused block.
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            )
            norm_1 = FasterTransformerRMSNorm(
                module.ln1.weight, module.ln1.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.ln2.weight, module.ln2.variance_epsilon
            )
            blocks.append(
                LlamaLikeBlock(
                    hidden_size=config.hidden_size,
                    n_heads=config.num_attention_heads,
                    n_kv_heads=config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    mlp=module.mlp,
                    norm_1=norm_1,
                    norm_2=norm_2,
                    dev=device,
                    max_seq_len=config.max_seq_len,
                    rope_theta=config.rope_theta,
                )
            )

        self.model.model = LlamaLikeModel(
            config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
        # NOTE(review): this re-assigns ``blocks`` to itself and looks
        # redundant; kept as in the original -- confirm before removing.
        setattr(self.model.model, "blocks", self.model.model.blocks)