mlp.py 2.88 KB
Newer Older
Haotian Tang's avatar
Haotian Tang committed
1
2
3
import torch
import torch.nn as nn
import awq_inference_engine
Casper Hansen's avatar
Casper Hansen committed
4
import torch.nn.functional as F
Casper Hansen's avatar
Casper Hansen committed
5
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
Haotian Tang's avatar
Haotian Tang committed
6

7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class QuantMPTMLP(nn.Module):
    """MPT feed-forward block whose up-projection calls the AWQ kernel directly.

    The up-projection's quantized tensors are registered as buffers on this
    module and passed straight to ``awq_inference_engine``; the activation and
    the down-projection are applied through the modules supplied by the caller.
    """

    def __init__(
        self,
        up_proj,
        act,
        down_proj
    ):
        """Build the fused block from already-quantized projection layers.

        Args:
            up_proj: Quantized linear layer; its ``qweight``, ``scales`` and
                ``qzeros`` tensors are re-registered here as buffers.
            act: Callable applied to the up-projection output.
            down_proj: Layer applied after the activation. Its type also
                decides which CUDA kernel runs the up-projection.
        """
        super().__init__()
        # Expose the quantized tensors directly so forward() can hand them to
        # the kernel without going through up_proj.forward().
        self.register_buffer('up_proj_qweight', up_proj.qweight)
        self.register_buffer('up_proj_scales', up_proj.scales)
        self.register_buffer('up_proj_qzeros', up_proj.qzeros)

        self.up_proj = up_proj
        self.act = act
        self.down_proj = down_proj

        # Kernel choice follows the type of down_proj.
        if isinstance(down_proj, WQLinear_GEMV):
            self.linear = awq_inference_engine.gemv_forward_cuda
            self.group_size = down_proj.group_size
        else:
            self.linear = awq_inference_engine.gemm_forward_cuda
            # NOTE(review): the GEMM kernel receives a fixed 8 as its last
            # argument; this is presumably a kernel tuning parameter rather
            # than a quantization group size -- confirm against the extension.
            self.group_size = 8

    def forward(self, x: torch.Tensor):
        """Apply up-projection, activation and down-projection to ``x``.

        Args:
            x: Input tensor; every dimension except the last is treated as a
                batch dimension.

        Returns:
            The down-projection output with the leading dimensions of ``x``
            restored, so it can be combined with tensors shaped like ``x``
            (e.g. a residual connection) for any batch size.
        """
        leading_dims = x.shape[:-1]

        # The kernel is called on a 2-D (rows, features) view.
        flat = x.reshape(-1, x.shape[-1])
        hidden = self.linear(
            flat,
            self.up_proj_qweight,
            self.up_proj_scales,
            self.up_proj_qzeros,
            self.group_size
        )

        out = self.down_proj(self.act(hidden))

        # Undo the flattening; previously the 2-D shape leaked to the caller.
        return out.reshape(*leading_dims, out.shape[-1])
Haotian Tang's avatar
Haotian Tang committed
41
42
43
44
45
46
47

class QuantLlamaMLP(nn.Module):
    """Llama gated feed-forward block with direct AWQ kernel calls.

    Computes ``down_proj(silu(gate(x)) * up(x))``. The gate and up projections
    run through ``awq_inference_engine`` using quantized tensors registered as
    buffers on this module; the down-projection is applied through the module
    supplied by the caller.
    """

    def __init__(
        self,
        gate_proj,
        down_proj,
        up_proj
    ):
        """Build the fused block from already-quantized projection layers.

        Args:
            gate_proj: Quantized linear layer providing ``qweight``,
                ``scales``, ``qzeros``, ``in_features``, ``out_features`` and
                ``w_bit``.
            down_proj: Layer applied to the gated product; provides
                ``out_features``. Its type also decides which CUDA kernel runs
                the gate and up projections.
            up_proj: Quantized linear layer providing ``qweight``, ``scales``
                and ``qzeros``.
        """
        super().__init__()
        # Expose the quantized tensors directly so forward() can hand them to
        # the kernel without going through the wrapped layers' forward().
        self.register_buffer('gate_proj_qweight', gate_proj.qweight)
        self.register_buffer('gate_proj_scales', gate_proj.scales)
        self.register_buffer('gate_proj_qzeros', gate_proj.qzeros)
        self.register_buffer('up_proj_qweight', up_proj.qweight)
        self.register_buffer('up_proj_scales', up_proj.scales)
        self.register_buffer('up_proj_qzeros', up_proj.qzeros)

        self.in_features = gate_proj.in_features
        self.intermediate_size = gate_proj.out_features
        self.out_features = down_proj.out_features
        self.w_bit = gate_proj.w_bit
        self.down_proj = down_proj

        # Kernel choice follows the type of down_proj.
        if isinstance(down_proj, WQLinear_GEMV):
            self.linear = awq_inference_engine.gemv_forward_cuda
            self.group_size = down_proj.group_size
        else:
            self.linear = awq_inference_engine.gemm_forward_cuda
            # NOTE(review): the GEMM kernel receives a fixed 8 as its last
            # argument; this is presumably a kernel tuning parameter rather
            # than a quantization group size -- confirm against the extension.
            self.group_size = 8

    def _project(self, flat, qweight, scales, qzeros):
        """Run the selected kernel on a 2-D input with one set of quantized tensors."""
        return self.linear(flat, qweight, scales, qzeros, self.group_size)

    def forward(self, x):
        """Apply the gated MLP to ``x``.

        Args:
            x: Input tensor; every dimension except the last is treated as a
                batch dimension.

        Returns:
            Output of ``down_proj`` applied to the gated product, which is
            first reshaped to the leading dimensions of ``x`` plus
            ``intermediate_size``.
        """
        out_shape = x.shape[:-1] + (self.intermediate_size,)

        # The kernel is called on a 2-D (rows, features) view.
        flat = x.reshape(-1, x.shape[-1])
        gate_output = self._project(
            flat,
            self.gate_proj_qweight,
            self.gate_proj_scales,
            self.gate_proj_qzeros,
        )
        up_output = self._project(
            flat,
            self.up_proj_qweight,
            self.up_proj_scales,
            self.up_proj_qzeros,
        )

        gated = F.silu(gate_output) * up_output
        # Restore the caller's leading dimensions before the down-projection.
        return self.down_proj(gated.reshape(out_shape))