# qwen2.py
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2DecoderLayer as OldQwen2DecoderLayer,
    Qwen2ForCausalLM as OldQwen2ForCausalLM,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm


class Qwen2AWQForCausalLM(BaseAWQForCausalLM):
    """AWQ adapter for Qwen2 causal language models.

    Provides the Qwen2-specific static hooks (layer lookup, embedding
    placement, scaling groups and layer fusion) on top of
    ``BaseAWQForCausalLM``.
    """

    # Class name of the decoder layer handled by this adapter.
    layer_type = "Qwen2DecoderLayer"
    # Presumably the config field the base class reads to obtain the maximum
    # sequence length -- confirm against BaseAWQForCausalLM.
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldQwen2ForCausalLM):
        """Fuse the decoder layers of ``model`` in place via ``Qwen2Fuser``."""
        fuser = Qwen2Fuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldQwen2ForCausalLM):
        """Return the decoder layer container, ``model.model.layers``."""
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldQwen2DecoderLayer):
        """Return the activation-scaling descriptor for ``module``.

        Always returns ``dict(is_scalable=False)``; ``module`` is not
        inspected.
        """
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model: OldQwen2ForCausalLM, device: str):
        """Move the token embedding of ``model`` to ``device``."""
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldQwen2DecoderLayer, input_feat, module_kwargs):
        """Describe the groups of linear layers in one decoder layer.

        Args:
            module: The decoder layer whose sub-modules are grouped.
            input_feat: Mapping indexed by sub-module name
                (``"self_attn.q_proj"``, ``"self_attn.o_proj"``,
                ``"mlp.gate_proj"``, ``"mlp.down_proj"``); the looked-up
                values are passed through as ``inp``.
            module_kwargs: Passed through unchanged as ``kwargs`` of the
                attention-input group.

        Returns:
            A list of dicts, each with ``prev_op``, ``layers`` and ``inp``
            keys, plus ``module2inspect`` / ``kwargs`` where set below.
        """
        layers = []

        # attention input
        layers.append(
            dict(
                prev_op=module.input_layernorm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        # The group is only emitted when v_proj and o_proj weights have the
        # same shape; "self_attn.o_proj" is only looked up in that case.
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        # linear 1
        layers.append(
            dict(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.gate_proj, module.mlp.up_proj],
                inp=input_feat["mlp.gate_proj"],
                module2inspect=module.mlp,
            )
        )

        # linear 2
        layers.append(
            dict(
                prev_op=module.mlp.up_proj,
                layers=[module.mlp.down_proj],
                inp=input_feat["mlp.down_proj"],
            )
        )

        return layers


class Qwen2Fuser:
    """Replace a Qwen2 model's decoder stack with fused AWQ modules."""

    def __init__(self, model: OldQwen2ForCausalLM):
        """Store ``model`` and collect its decoder layers by module name.

        Args:
            model: Model whose ``named_modules()`` is scanned; every module
                whose class name contains ``Qwen2DecoderLayer``
                (case-insensitive) is recorded in ``self.qwen2_blocks``.
        """
        self.model = model

        self.qwen2_blocks: List[Tuple[str, OldQwen2DecoderLayer]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "Qwen2DecoderLayer".lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        """Build one ``LlamaLikeBlock`` per decoder layer and swap the model.

        Iterates ``self.model.model.layers``, fuses each layer's q/k/v
        projections, wraps both layer norms in ``FasterTransformerRMSNorm``
        and finally replaces ``self.model.model`` with a ``LlamaLikeModel``
        built from the new blocks.
        """
        blocks = []

        module: OldQwen2DecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            # Device of the first entry in the layer's state dict; the fused
            # block is created on that device.
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight, module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon,
            )
            blocks.append(
                LlamaLikeBlock(
                    hidden_size=self.model.config.hidden_size,
                    n_heads=self.model.config.num_attention_heads,
                    n_kv_heads=self.model.config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    mlp=module.mlp,
                    norm_1=norm_1,
                    norm_2=norm_2,
                    dev=device,
                    # NOTE(review): `max_seq_len` is not set anywhere in this
                    # file; presumably the loader populates it on the config
                    # -- confirm.
                    max_seq_len=self.model.config.max_seq_len,
                )
            )

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
        # NOTE(review): re-binds `blocks` to its current value; looks like a
        # no-op but is kept as-is -- confirm before removing.
        setattr(self.model.model, "blocks", self.model.model.blocks)