# baichuan.py
import tqdm
from typing import List, Tuple
from torch import nn
from .base import BaseAWQForCausalLM
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from awq.modules.fused.norm import FasterTransformerRMSNorm
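
# AWQ integration for Baichuan models. Baichuan packs the query/key/value
# projections into a single `self_attn.W_pack` linear, so both the scaling
# groups in `get_layers_for_scaling` and the fused inference path in
# `BaichuanFuser` operate on `W_pack` instead of separate q/k/v projections.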

class BaichuanAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "BaichuanLayer"
    max_new_tokens_key = "model_max_length"

    @staticmethod
    def fuse_layers(model):
        fuser = BaichuanFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model):
        return model.model.layers
    
    @staticmethod
    def get_act_for_scaling(module):
        return dict(
            is_scalable=False
        )
    
    @staticmethod
    def move_embed(model, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)
    
    @staticmethod
    def get_layers_for_scaling(module, input_feat, module_kwargs):
        """Return the AWQ scaling groups for a single Baichuan decoder layer."""
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.W_pack],
            inp=input_feat['self_attn.W_pack'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        layers.append(dict(
            prev_op=module.self_attn.W_pack,
            layers=[module.self_attn.o_proj],
            inp=input_feat['self_attn.o_proj'],
        ))
        
        # linear 1
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.gate_proj, module.mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=module.mlp,
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.up_proj,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return layers


class BaichuanFuser:
    def __init__(self, model):
        self.model = model

        # Collect the Baichuan decoder layers by class name ("BaichuanLayer",
        # as declared in BaichuanAWQForCausalLM.layer_type).
        self.baichuan_blocks: List[Tuple[str, nn.Module]] = [
            (name, module) for name, module in self.model.named_modules()
            if "BaichuanLayer".lower() in module.__class__.__name__.lower()
        ]
    
    def fuse_transformer(self):
        blocks = []

        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            # Baichuan already packs Q/K/V into a single `W_pack` linear,
            # so it is used directly as the fused QKV layer (no fuse_qkv needed).
            qkv = module.self_attn.W_pack
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.epsilon
            )
            blocks.append(LlamaLikeBlock(
                hidden_size=self.model.config.hidden_size,
                n_heads=self.model.config.num_attention_heads,
                n_kv_heads=self.model.config.num_attention_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=module.mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens,
                # Baichuan-13B uses ALiBi position biases rather than rotary embeddings
                use_alibi=True
            ))
        
        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )

        setattr(self.model.model, "blocks", self.model.model.blocks)
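

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the library module itself):
    # quantize a Baichuan checkpoint through the top-level AutoAWQ API, which
    # dispatches to BaichuanAWQForCausalLM above. The model path, output
    # directory, and quant_config values below are example assumptions.
    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    model_path = "baichuan-inc/Baichuan2-13B-Chat"  # example checkpoint
    quant_path = "baichuan2-13b-awq"                # example output directory
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

    # Baichuan ships custom modeling code, so trust_remote_code is required.
    model = AutoAWQForCausalLM.from_pretrained(model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Run AWQ calibration/quantization, then save the quantized weights.
    model.quantize(tokenizer, quant_config=quant_config)
    model.save_quantized(quant_path)
    tokenizer.save_pretrained(quant_path)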