"vscode:/vscode.git/clone" did not exist on "7e33a017c086d2dbde3be0a546d40818bb1e9c16"
falcon.py 4.25 KB
Newer Older
Casper Hansen's avatar
Casper Hansen committed
1
from .base import BaseAWQForCausalLM
from transformers.models.falcon.modeling_falcon import (
    FalconDecoderLayer as OldFalconDecoderLayer,
    FalconForCausalLM,
    FalconAttention,
)


class FalconAWQForCausalLM(BaseAWQForCausalLM):
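    # AWQ adapter for Falcon: tells the AWQ base class where the decoder layers
    # live, which ops to scale, and how to (optionally) swap in fused modules.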
    layer_type = "FalconDecoderLayer"

    @staticmethod
    def fuse_layers(model: FalconForCausalLM):
        fuser = FalconFuser(model)

        # TODO: Implement correctly fused modules for Falcon 40B and Falcon 180B
        if model.config.num_attention_heads == 71:
            fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: FalconForCausalLM):
        return model.transformer.h

    @staticmethod
    def get_act_for_scaling(module: OldFalconDecoderLayer):
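        # The MLP activation is treated as a scalable op; its expected input
        # width equals the out_features of the preceding dense_h_to_4h layer.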
        return dict(
            is_scalable=True,
            scale_name="mlp.act",
            scale_layer=module.mlp.act,
            scale_shape=module.mlp.dense_h_to_4h.out_features,
        )

    @staticmethod
    def move_embed(model: FalconForCausalLM, device):
        model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)

    @staticmethod
    def get_layers_for_scaling(
        module: OldFalconDecoderLayer, input_feat, module_kwargs
    ):
        layers = []
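        # Each entry groups linear layers that share an input, together with the
        # op that feeds them and the captured activations used to search scales.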

        # Falcon 7B (older architecture)
        if module.config.num_attention_heads == 71:
            # linear 1 + attention
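            # The shared input_layernorm feeds both the QKV projection and the
            # MLP up-projection, so the two are scaled as a single group.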
            layers.append(
                dict(
                    prev_op=module.input_layernorm,
                    layers=[
                        module.mlp.dense_h_to_4h,
                        module.self_attention.query_key_value,
                    ],
                    inp=input_feat["self_attention.query_key_value"],
                    module2inspect=module,
                    kwargs=module_kwargs,
                )
            )

        # Falcon 40B (newer architecture)
        else:
            # linear 1 + attention
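            # The newer decoder architecture uses separate layernorms (ln_attn /
            # ln_mlp), so attention and MLP are handled as separate groups.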
            layers.append(
                dict(
                    prev_op=module.ln_attn,
                    layers=[module.self_attention.query_key_value],
                    inp=input_feat["self_attention.query_key_value"],
                    module2inspect=module,
                    kwargs=module_kwargs,
                )
            )

            # linear 2
            layers.append(
                dict(
                    prev_op=module.ln_mlp,
                    layers=[module.mlp.dense_h_to_4h],
                    inp=input_feat["mlp.dense_h_to_4h"],
                    module2inspect=module,
                    kwargs=module_kwargs,
                )
            )

        return layers

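# NOTE: the fused-module imports are kept below the quantization class,
# presumably to avoid a circular import with the fused modules.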
from awq.modules.fused.model import FalconModel
from awq.modules.fused.block import FalconDecoderLayer


class FalconFuser:
    def __init__(self, model: FalconForCausalLM):
        self.model = model

    def fuse_transformer(self):
        blocks = []

        module: OldFalconDecoderLayer
        for module in self.model.transformer.h:
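            # 71 attention heads identifies the Falcon-7B-style block with a single
            # layernorm; otherwise assume the newer split-layernorm architecture.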
            if module.config.num_attention_heads == 71:
                input_layernorm = module.input_layernorm
                ln_attn = None
                ln_mlp = None
                new_decoder_arch = False
            else:
                input_layernorm = None
                ln_attn = module.ln_attn
                ln_mlp = module.ln_mlp
                new_decoder_arch = True

            blocks.append(
                FalconDecoderLayer(
                    hidden_size=module.config.hidden_size,
                    n_heads=module.config.num_attention_heads,
                    qkv_layer=module.self_attention.query_key_value,
                    o_proj=module.self_attention.dense,
                    mlp=module.mlp,
                    dev=next(iter(module.state_dict().values())).device,
                    max_seq_len=self.model.config.max_seq_len,
                    input_layernorm=input_layernorm,
                    ln_attn=ln_attn,
                    ln_mlp=ln_mlp,
                    new_decoder_arch=new_decoder_arch,
                )
            )

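        # Replace the original transformer with a fused model built from the new
        # blocks, reusing the existing word embeddings and final layernorm.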
        self.model.transformer = FalconModel(
            self.model.config.vocab_size,
            blocks,
            self.model.transformer.word_embeddings,
            self.model.transformer.ln_f,
        )

        setattr(self.model.transformer, "blocks", self.model.transformer.blocks)