# mpt.py
from .base import BaseAWQForCausalLM
from transformers.models.mpt.modeling_mpt import MptBlock, MptForCausalLM, MptMLP, MptAttention, LayerNorm

class MptAWQForCausalLM(BaseAWQForCausalLM):
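    # MPT's decoder block class name and the config key holding its maximum sequence length.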
    layer_type = "MPTBlock"
    max_new_tokens_key = "max_seq_len"

    @staticmethod
    def fuse_layers(model: MptForCausalLM, quant_config: dict):
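        # Swap attention and layernorm modules for fused implementations at load time.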
        fuser = MptFuser(model)
        fuser.fuse_attention()
        fuser.fuse_layernorm()

    @staticmethod
    def get_model_layers(model: MptForCausalLM):
        return model.transformer.blocks
    
    @staticmethod
    def get_act_for_scaling(module: MptBlock):
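        # AWQ can scale the output of the MLP activation; its width equals up_proj.out_features.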
        return dict(
            is_scalable=True,
            scale_name="ffn.act",
            scale_layer=module.ffn.act,
            scale_shape=module.ffn.up_proj.out_features
        )
    
    @staticmethod
    def move_embed(model: MptForCausalLM, device: str):
        model.transformer.wte = model.transformer.wte.to(device)
        model.transformer.emb_drop = model.transformer.emb_drop.to(device)
    
    @staticmethod
    def get_layers_for_scaling(module: MptBlock, input_feat, module_kwargs):
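        # Each entry pairs the linear layer(s) to scale with the op producing their
        # input (prev_op) and the input features captured for that layer.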
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.norm_1,
            layers=[module.attn.Wqkv],
            inp=input_feat['attn.Wqkv'],
            module2inspect=module.attn,
            kwargs=module_kwargs
        ))

        # attention output
        layers.append(dict(
            prev_op=module.attn.Wqkv,
            layers=[module.attn.out_proj],
            inp=input_feat['attn.out_proj']
        ))

        # linear 1
        layers.append(dict(
            prev_op=module.norm_2,
            layers=[module.ffn.up_proj],
            inp=input_feat['ffn.up_proj'],
            module2inspect=module.ffn
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.ffn.act,
            layers=[module.ffn.down_proj],
            inp=input_feat['ffn.down_proj']
        ))

        return layers

import torch
import xformers
from typing import List, Tuple
from awq.utils.utils import set_module_name
from xformers.triton.layer_norm import FusedLayerNorm
from awq.modules.fused.attn import QuantAttentionFused

class MptFuser:
    def __init__(self, model):
        self.model = model

        self.attention_modules: List[Tuple[str, MptAttention]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, MptAttention)
        ]

        self.layernorm_modules: List[Tuple[str, LayerNorm]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LayerNorm)
        ]

        self.mlp_modules: List[Tuple[str, MptMLP]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, MptMLP)
        ]
    
    def fuse_attention(self):
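        # `qkv_layer` is the full MptAttention module; its Wqkv and out_proj are reused
        # by QuantAttentionFused. use_alibi=True because MPT relies on ALiBi position biases.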
        for name, qkv_layer in self.attention_modules:
            attn = QuantAttentionFused(
                qkv_layer.hidden_size,
                qkv_layer.n_heads,
                qkv_layer, 
                qkv_layer.out_proj,
                next(iter(qkv_layer.state_dict().values())).device,
                self.model.config.max_new_tokens,
                use_alibi=True
            )
            set_module_name(self.model, name, attn)

    def fuse_layernorm(self):
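        # Swap every LayerNorm for xformers' Triton FusedLayerNorm (fp16 kernel enabled
        # below) and carry over the original weight and bias.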
        xformers.triton.k_layer_norm._triton_layernorm_fp16_enabled = True
        for name, module in self.layernorm_modules:
            norm = FusedLayerNorm(module.weight.shape, eps=module.eps).to(module.weight.device)

            # copy weights and bias
            with torch.no_grad():
                norm.weight = module.weight
                norm.bias = module.bias

            set_module_name(self.model, name, norm)

    def fuse_mlp(self):
        # No MLP fusion is implemented for MPT; this method is a no-op placeholder.
        pass
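

# Minimal usage sketch, not part of this module: it assumes AutoAWQ's public loader
# `AutoAWQForCausalLM.from_quantized` with a `fuse_layers=True` flag, which routes MPT
# checkpoints through MptAWQForCausalLM.fuse_layers above. The checkpoint path is
# hypothetical and the exact loader signature may differ between AutoAWQ releases.
if __name__ == "__main__":
    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    quant_path = "path/to/mpt-7b-awq"  # hypothetical quantized MPT checkpoint
    model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
    tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)

    tokens = tokenizer("MosaicML's MPT is", return_tensors="pt").input_ids.to("cuda")
    print(tokenizer.decode(model.generate(tokens, max_new_tokens=32)[0]))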