import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer as OldLlamaDecoderLayer,
)
from transformers.models.llava.modeling_llava import (
    LlavaForConditionalGeneration as OldLlavaForConditionalGeneration,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm

class LlavaAWQForCausalLM(BaseAWQForCausalLM):
    """AWQ hooks for LLaVA models.

    Every hook below operates on the Llama language model nested inside the
    LLaVA model (``model.language_model``); the per-layer hooks receive
    ``LlamaDecoderLayer`` modules.
    """

    # Class name of the decoder layer used by this architecture
    # (presumably consumed by BaseAWQForCausalLM — confirm in the base class).
    layer_type = "LlamaDecoderLayer"
    # Name of the config attribute holding the maximum sequence length
    # (presumably read by BaseAWQForCausalLM when loading — confirm).
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldLlavaForConditionalGeneration):
        """Replace the language model's decoder stack with fused AWQ blocks."""
        fuser = LlavaFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldLlavaForConditionalGeneration):
        """Return the decoder layers of the LLaVA language model."""
        return model.language_model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldLlamaDecoderLayer):
        """Report that this layer has no activation module to scale."""
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model: OldLlavaForConditionalGeneration, device: str):
        """Move the input embedding to ``device``.

        The moved module is assigned back to
        ``model.language_model.model.embed_tokens``. This presumably relies on
        ``model.get_input_embeddings()`` returning the language model's
        embedding (verify against the transformers version in use).
        """
        model.language_model.model.embed_tokens = model.get_input_embeddings().to(
            device
        )

    @staticmethod
    def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
        """Describe the groups of linear layers that AWQ scales together.

        Args:
            module: One decoder layer of the language model.
            input_feat: Mapping from sub-module name (e.g. ``"self_attn.q_proj"``)
                to the captured input of that sub-module.
            module_kwargs: Keyword arguments forwarded when running the
                attention block during scale search.

        Returns:
            A list of dicts. Each dict has the keys ``prev_op`` (the op whose
            output feeds ``layers``), ``layers`` (linear layers sharing one
            scale) and ``inp`` (their captured input). Some dicts also carry
            ``module2inspect`` (the module re-run to evaluate a scale) and
            ``kwargs``.
        """
        layers = []

        # attention input: q/k/v all read the output of input_layernorm.
        layers.append(
            dict(
                prev_op=module.input_layernorm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        # Only scale o_proj through v_proj when their weight shapes match.
        # Presumably they differ under grouped-query attention (fewer KV heads),
        # in which case the scale cannot be folded into v_proj.
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        # linear 1: gate/up both read the output of post_attention_layernorm.
        layers.append(
            dict(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.gate_proj, module.mlp.up_proj],
                inp=input_feat["mlp.gate_proj"],
                module2inspect=module.mlp,
            )
        )

        # linear 2: the down_proj scale is folded into up_proj. This
        # presumably works because down_proj's input is act(gate) * up in the
        # Llama MLP — confirm.
        layers.append(
            dict(
                prev_op=module.mlp.up_proj,
                layers=[module.mlp.down_proj],
                inp=input_feat["mlp.down_proj"],
            )
        )

        return layers


class LlavaFuser:
    """Swap the decoder stack of a LLaVA language model for fused AWQ blocks.

    Only ``model.language_model`` is modified. Its ``.model`` attribute is
    replaced by a ``LlamaLikeModel`` built from the existing decoder layers,
    embedding and final norm.
    """

    def __init__(self, model: OldLlavaForConditionalGeneration):
        # Operate on the language-model half of LLaVA; its ``.model`` holds
        # the decoder layers, embeddings and final norm used below.
        self.model = model.language_model

        # (name, module) pairs for every sub-module whose class name contains
        # "LlamaDecoderLayer" (case-insensitive). fuse_transformer() does not
        # use this list; it walks ``self.model.model.layers`` directly. The
        # attribute is kept for callers that may inspect it.
        self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "LlamaDecoderLayer".lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        """Build one LlamaLikeBlock per decoder layer and install a LlamaLikeModel.

        After this call ``self.model.model`` is a ``LlamaLikeModel`` that
        reuses the original ``embed_tokens`` and ``norm`` modules.
        """
        # The config does not change during fusing, so read it once.
        config = self.model.config
        blocks = []

        module: OldLlamaDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            # Place the fused block on the device of the first tensor in this
            # layer's state dict.
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight, module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon,
            )
            blocks.append(
                LlamaLikeBlock(
                    hidden_size=config.hidden_size,
                    n_heads=config.num_attention_heads,
                    n_kv_heads=config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    mlp=module.mlp,
                    norm_1=norm_1,
                    norm_2=norm_2,
                    dev=device,
                    # NOTE(review): ``max_seq_len`` is not a native Llama config
                    # field; this presumably relies on the AWQ loader having set
                    # it on the language model's config — confirm.
                    max_seq_len=config.max_seq_len,
                )
            )

        # The former trailing ``setattr(self.model.model, "blocks",
        # self.model.model.blocks)`` assigned the attribute to itself and was
        # a no-op, so it has been removed.
        self.model.model = LlamaLikeModel(
            config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )