falcon.py 3.22 KB
Newer Older
Casper Hansen's avatar
Casper Hansen committed
1
from .base import BaseAWQForCausalLM
Casper Hansen's avatar
Casper Hansen committed
2
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconForCausalLM, FalconAttention
Casper Hansen's avatar
Casper Hansen committed
3
4
5
6

class FalconAWQForCausalLM(BaseAWQForCausalLM):
    """AWQ quantization adapter for Hugging Face ``FalconForCausalLM`` models."""

    # Class name of the repeated transformer block, stored as a string.
    layer_type = "FalconDecoderLayer"

    @staticmethod
    def fuse_layers(model: FalconForCausalLM, quant_config: dict):
        """Prepare fused modules for a quantized Falcon model.

        NOTE(review): a ``FalconFuser`` is constructed, but none of its fusion
        methods are called, so no module of ``model`` is replaced here.
        ``quant_config`` is currently unused.
        """
        fuser = FalconFuser(model)
        # NOTE(review): FalconFuser.fuse_attention exists but is deliberately
        # not invoked yet; enable it only once fused Falcon attention is verified.

    @staticmethod
    def get_model_layers(model: FalconForCausalLM):
        """Return the stack of decoder layers (``model.transformer.h``)."""
        return model.transformer.h

    @staticmethod
    def get_act_for_scaling(module: FalconDecoderLayer):
        """Describe the MLP activation as a scalable op.

        The scale shape equals the output width of ``mlp.dense_h_to_4h``, the
        projection whose output feeds ``mlp.act``.
        """
        return dict(
            is_scalable=True,
            scale_name="mlp.act",
            scale_layer=module.mlp.act,
            scale_shape=module.mlp.dense_h_to_4h.out_features,
        )

    @staticmethod
    def move_embed(model: FalconForCausalLM, device):
        """Move the token embedding table to ``device``."""
        model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)

    @staticmethod
    def get_layers_for_scaling(module: FalconDecoderLayer, input_feat, module_kwargs):
        """Return the groups of linear layers that share an input scale.

        The layout is detected from the decoder layer's own sub-modules rather
        than from the attention head count. This way every Falcon variant maps
        to the layernorm(s) that actually exist on the layer:

        * ``ln_attn`` + ``ln_mlp`` (separate parallel layernorms, e.g. 40B):
          one group per layernorm.
        * ``input_layernorm`` + ``post_attention_layernorm`` (attention then
          MLP, sequentially): one group per layernorm.
        * ``input_layernorm`` only (one layernorm shared by parallel attention
          and MLP, e.g. 7B): a single group with both projections.

        Args:
            module: decoder layer being quantized.
            input_feat: captured input activations, keyed by sub-module name
                (``"self_attention.query_key_value"``, ``"mlp.dense_h_to_4h"``).
            module_kwargs: keyword arguments used to run ``module`` forward.

        Returns:
            A list of dicts with ``prev_op``, ``layers``, ``inp``,
            ``module2inspect`` and, where the whole layer is inspected,
            ``kwargs``.
        """
        layers = []
        qkv = module.self_attention.query_key_value
        fc1 = module.mlp.dense_h_to_4h

        if hasattr(module, "ln_attn"):
            # Newer architecture: separate layernorms for attention and MLP.
            layers.append(dict(
                prev_op=module.ln_attn,
                layers=[qkv],
                inp=input_feat['self_attention.query_key_value'],
                module2inspect=module,
                kwargs=module_kwargs,
            ))
            layers.append(dict(
                prev_op=module.ln_mlp,
                layers=[fc1],
                inp=input_feat['mlp.dense_h_to_4h'],
                module2inspect=module,
                kwargs=module_kwargs,
            ))
        elif hasattr(module, "post_attention_layernorm"):
            # Sequential attention -> MLP, each behind its own layernorm.
            layers.append(dict(
                prev_op=module.input_layernorm,
                layers=[qkv],
                inp=input_feat['self_attention.query_key_value'],
                module2inspect=module,
                kwargs=module_kwargs,
            ))
            # The MLP input is exactly post_attention_layernorm's output, so
            # inspecting the MLP alone is sufficient (no layer kwargs needed).
            layers.append(dict(
                prev_op=module.post_attention_layernorm,
                layers=[fc1],
                inp=input_feat['mlp.dense_h_to_4h'],
                module2inspect=module.mlp,
            ))
        else:
            # Older parallel architecture: one input_layernorm feeds both the
            # MLP and attention, so their projections share one scale.
            layers.append(dict(
                prev_op=module.input_layernorm,
                layers=[fc1, qkv],
                inp=input_feat['self_attention.query_key_value'],
                module2inspect=module,
                kwargs=module_kwargs,
            ))

        return layers

import torch
from torch.nn import LayerNorm
from typing import List, Tuple
from awq.utils.utils import set_module_name
from awq.modules.fused.attn import QuantAttentionFused

class FalconFuser:
    """Locate Falcon sub-modules and swap them for fused AWQ modules.

    Only attention fusion is implemented. The list of LayerNorm modules is
    collected, but nothing in this class uses it yet.
    """

    def __init__(self, model):
        """Scan ``model`` once and record its attention and LayerNorm modules.

        Args:
            model: the (quantized) Falcon model whose modules may be replaced.
        """
        self.model = model

        self.attention_modules: List[Tuple[str, FalconAttention]] = []
        self.layernorm_modules: List[Tuple[str, LayerNorm]] = []
        # Single traversal; each list keeps named_modules() order.
        for name, module in self.model.named_modules():
            if isinstance(module, FalconAttention):
                self.attention_modules.append((name, module))
            if isinstance(module, LayerNorm):
                self.layernorm_modules.append((name, module))

    def fuse_attention(self):
        """Replace every ``FalconAttention`` with a ``QuantAttentionFused``.

        The fused module receives the attention block's ``query_key_value``
        projection (not the attention block itself) and its ``dense`` output
        projection. It is placed on the device of the qkv projection's first
        state-dict tensor.

        NOTE(review): assumes ``model.config.max_new_tokens`` has been set by
        the loader. Also unverified here: whether QuantAttentionFused supports
        Falcon's multi-query/grouped-query heads and its positional encoding.
        Confirm both before enabling this from ``fuse_layers``.
        """
        for name, attn_module in self.attention_modules:
            qkv_layer = attn_module.query_key_value
            fused_attn = QuantAttentionFused(
                attn_module.hidden_size,
                attn_module.num_heads,
                qkv_layer,
                attn_module.dense,
                next(iter(qkv_layer.state_dict().values())).device,
                self.model.config.max_new_tokens,
            )
            set_module_name(self.model, name, fused_attn)