mlp.py
import torch.nn as nn
import torch.nn.functional as F
from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV

try:
    import awq_ext  # with CUDA kernels

    AWQ_INSTALLED = True
except:
    AWQ_INSTALLED = False


class QuantFusedMLP(nn.Module):
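    r"""
    Fused quantized MLP: holds the AWQ-quantized gate and up projections as raw
    buffers and calls the awq_ext CUDA kernels on them directly, computing
    ``activation(gate(x)) * up(x)`` before applying ``down_proj``.
    """
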
    def __init__(
        self,
        gate_proj,
        down_proj,
        up_proj,
        activation=F.silu,
    ):
        super().__init__()

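        # Keep the quantized gate/up projection tensors as buffers so they follow
        # .to()/device moves and are included in the module's state dict.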
        self.register_buffer("gate_proj_qweight", gate_proj.qweight)
        self.register_buffer("gate_proj_scales", gate_proj.scales)
        self.register_buffer("gate_proj_qzeros", gate_proj.qzeros)
        self.register_buffer("up_proj_qweight", up_proj.qweight)
        self.register_buffer("up_proj_scales", up_proj.scales)
        self.register_buffer("up_proj_qzeros", up_proj.qzeros)

        self.in_features = gate_proj.in_features
        self.intermediate_size = gate_proj.out_features
        self.out_features = down_proj.out_features
        self.w_bit = gate_proj.w_bit
        self.down_proj = down_proj

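        # Pick the fused CUDA kernel matching the quantized linear implementation.
        # The GEMV kernel takes the layer's real quantization group size; the GEMM
        # path passes a constant 8 as the kernel's final argument instead.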
        if isinstance(down_proj, WQLinear_GEMV):
            self.linear = awq_ext.gemv_forward_cuda
            self.group_size = down_proj.group_size
        else:
            self.linear = awq_ext.gemm_forward_cuda
            self.group_size = 8

        self.activation = activation

    def forward(self, x, routing_weights=None):
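        # out_shape is the hidden-state shape with the intermediate dimension;
        # down_proj maps it back to out_features at the end of forward().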
        out_shape = x.shape[:-1] + (self.intermediate_size,)
        x = x.reshape(-1, x.shape[-1])
        gate_output = self.linear(
            x,
            self.gate_proj_qweight,
            self.gate_proj_scales,
            self.gate_proj_qzeros,
            self.group_size,
        )
        up_output = self.linear(
            x,
            self.up_proj_qweight,
            self.up_proj_scales,
            self.up_proj_qzeros,
            self.group_size,
        )
        x = self.activation(gate_output) * up_output
        x = x.reshape(out_shape)
        x = self.down_proj(x)

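        # Optionally scale the output by per-token routing weights (useful when the
        # fused MLP is used as an expert in a mixture-of-experts layer).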
        if routing_weights is not None:
            x = routing_weights * x

        return x


class QuantLlamaMLP(QuantFusedMLP):
    r"""
    The QuantLlamaMLP class is kept for backward compatibility; in the future,
    users should always use the `QuantFusedMLP` class instead.
    """

    def __init__(self, gate_proj, down_proj, up_proj):
        super().__init__(gate_proj, down_proj, up_proj)
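

# Usage sketch (with a hypothetical `model`): assuming a Hugging Face Llama-style
# model whose MLP projections have already been quantized to WQLinear_GEMM or
# WQLinear_GEMV by AutoAWQ, each decoder layer's MLP could be swapped for the
# fused module roughly like this:
#
#   for layer in model.model.layers:
#       mlp = layer.mlp
#       layer.mlp = QuantFusedMLP(mlp.gate_proj, mlp.down_proj, mlp.up_proj)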