q_linear.py 5.59 KB
Newer Older
1
2
3
import torch
import torch.nn as nn

gushiqiao's avatar
gushiqiao committed
4
5
6
7
try:
    from vllm import _custom_ops as ops
except ModuleNotFoundError:
    ops = None
8

gushiqiao's avatar
gushiqiao committed
9
10
11
12
13
14
15
try:
    from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
except ModuleNotFoundError:
    quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None


class VllmQuantLinearInt8(nn.Module):
    """Linear layer with int8 weights that runs its matmul through vLLM kernels.

    The layer owns three buffers (not parameters):

    * ``weight``: int8, shape ``(out_features, in_features)``
    * ``weight_scale``: float32, shape ``(out_features, 1)``
    * ``bias``: ``dtype``, shape ``(out_features,)``, or ``None`` when ``bias=False``

    All buffers are allocated with ``torch.empty`` and therefore hold
    uninitialized memory until real values are loaded into them
    (presumably via ``load_state_dict`` -- confirm against the caller).
    """

    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
        """Allocate the quantized weight, its scale and the optional bias.

        Args:
            in_features: Size of the last dimension of the input.
            out_features: Size of the last dimension of the output.
            bias: When false, the ``bias`` buffer is registered as ``None``.
            dtype: Dtype of the ``bias`` buffer only; weight and scale dtypes are fixed.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
        self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))

        if bias:
            self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
        else:
            self.register_buffer("bias", None)

    def act_quant_func(self, x):
        """Quantize activations ``x`` with ``ops.scaled_int8_quant``.

        Returns:
            ``(quantized_activations, activation_scale)``; the third value
            returned by the vLLM op is discarded.

        Raises:
            ImportError: If vllm could not be imported at module load time.
        """
        if ops is None:
            # The module-level import fallback leaves `ops` as None; fail with
            # an actionable message instead of an AttributeError on None.
            raise ImportError("vllm is not installed; VllmQuantLinearInt8 requires vllm._custom_ops")
        # scale=None / symmetric=True: presumably dynamic symmetric quantization -- confirm against vLLM docs.
        input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
        return input_tensor_quant, input_tensor_scale

    def forward(self, input_tensor):
        """Apply the quantized linear transform.

        The leading dimension is squeezed before the matmul and restored
        afterwards, so the input is assumed to be ``(1, tokens, in_features)``
        -- TODO confirm; other ranks are not handled here.

        Returns:
            Tensor of shape ``(1, tokens, out_features)`` with the dtype and
            device of ``input_tensor``.
        """
        input_tensor = input_tensor.squeeze(0)
        shape = (input_tensor.shape[0], self.weight.shape[0])
        output_tensor = torch.empty(
            shape,
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        # The kernel writes its result into `output_tensor` in place.
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
            input_tensor_quant,
            self.weight.t(),
            input_tensor_scale,
            self.weight_scale.float(),
            self.bias,
        )
        return output_tensor.unsqueeze(0)

    def _apply(self, fn, recurse=True):
        """Apply ``fn`` to the buffers only when it moves them to another device.

        Results of ``fn`` that stay on the same device (for example a pure
        dtype cast from ``module.to(torch.float16)``) are ignored, so the
        quantized dtypes are never altered by ``.to()``/``.half()`` calls.

        Args:
            fn: Tensor -> tensor conversion supplied by ``nn.Module``.
            recurse: Mirrors ``nn.Module._apply``; when false, child modules
                are skipped (used by ``to_empty(recurse=False)``).

        Returns:
            ``self``.
        """
        if recurse:
            for module in self.children():
                module._apply(fn)

        for name in ("weight", "weight_scale", "bias"):
            tensor = getattr(self, name)
            if tensor is None:
                continue
            # Call fn exactly once: it may perform a full device transfer.
            converted = fn(tensor)
            if converted.device != tensor.device:
                setattr(self, name, converted)
        return self

65

gushiqiao's avatar
gushiqiao committed
66
class VllmQuantLinearFp8(nn.Module):
gushiqiao's avatar
gushiqiao committed
67
    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
68
69
70
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
gushiqiao's avatar
FIX  
gushiqiao committed
71
72
        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))
        self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
73
        if bias:
gushiqiao's avatar
gushiqiao committed
74
            self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
        else:
            self.register_buffer("bias", None)

    def act_quant_func(self, x):
        input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
        return input_tensor_quant, input_tensor_scale

    def forward(self, input_tensor):
        input_tensor = input_tensor.squeeze(0)
        shape = (input_tensor.shape[0], self.weight.shape[0])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
91
            input_tensor_quant,
92
            self.weight.t(),
93
94
            input_tensor_scale,
            self.weight_scale.float(),
95
            self.bias,
96
        )
gushiqiao's avatar
gushiqiao committed
97

98
        return output_tensor.unsqueeze(0)
gushiqiao's avatar
gushiqiao committed
99
100
101
102
103
104
105
106
107
108
109
110
111
112

    def _apply(self, fn):
        for module in self.children():
            module._apply(fn)

        def maybe_cast(t):
            if t is not None and t.device != fn(t).device:
                return fn(t)
            return t

        self.weight = maybe_cast(self.weight)
        self.weight_scale = maybe_cast(self.weight_scale)
        self.bias = maybe_cast(self.bias)
        return self
gushiqiao's avatar
gushiqiao committed
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154


class TorchaoQuantLinearInt8(nn.Module):
    """Linear layer with int8 weights that runs its matmul through torchao helpers.

    The layer owns three buffers (not parameters):

    * ``weight``: int8, shape ``(out_features, in_features)``
    * ``weight_scale``: float32, shape ``(out_features, 1)``
    * ``bias``: ``dtype``, shape ``(out_features,)``, or ``None`` when ``bias=False``

    All buffers are allocated with ``torch.empty`` and therefore hold
    uninitialized memory until real values are loaded into them
    (presumably via ``load_state_dict`` -- confirm against the caller).
    """

    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
        """Allocate the quantized weight, its scale and the optional bias.

        Args:
            in_features: Size of the last dimension of the input.
            out_features: Size of the last dimension of the output.
            bias: When false, the ``bias`` buffer is registered as ``None``.
            dtype: Dtype of the ``bias`` buffer only; weight and scale dtypes are fixed.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
        self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))

        if bias:
            self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
        else:
            self.register_buffer("bias", None)

    def act_quant_func(self, x):
        """Quantize activations ``x`` with ``quantize_activation_per_token_absmax``.

        Returns:
            ``(quantized_activations, activation_scale)``.

        Raises:
            ImportError: If torchao could not be imported at module load time.
        """
        if quantize_activation_per_token_absmax is None:
            # The module-level import fallback leaves the helpers as None; fail
            # with an actionable message instead of "'NoneType' is not callable".
            raise ImportError("torchao is not installed; TorchaoQuantLinearInt8 requires torchao.quantization.utils")
        input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
        return input_tensor_quant, input_tensor_scale

    def forward(self, input_tensor):
        """Apply the quantized linear transform.

        The leading dimension is squeezed before the matmul and restored
        afterwards, so the input is assumed to be ``(1, tokens, in_features)``
        -- TODO confirm; other ranks are not handled here.

        Returns:
            The matmul result (plus ``bias`` when present) with a leading
            dimension of size 1 added back.
        """
        input_tensor = input_tensor.squeeze(0)
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        # NOTE(review): the output dtype is fixed to bfloat16 regardless of the
        # input dtype, unlike the vLLM-backed layers -- confirm this is intended.
        output_tensor = quant_int8_per_token_matmul(
            input_tensor_quant,
            input_tensor_scale,
            self.weight.t(),
            self.weight_scale.t().float(),
            output_dtype=torch.bfloat16,
        )
        if self.bias is not None:
            output_tensor = output_tensor + self.bias

        return output_tensor.unsqueeze(0)

    def _apply(self, fn, recurse=True):
        """Apply ``fn`` to the buffers only when it moves them to another device.

        Results of ``fn`` that stay on the same device (for example a pure
        dtype cast from ``module.to(torch.float16)``) are ignored, so the
        quantized dtypes are never altered by ``.to()``/``.half()`` calls.

        Args:
            fn: Tensor -> tensor conversion supplied by ``nn.Module``.
            recurse: Mirrors ``nn.Module._apply``; when false, child modules
                are skipped (used by ``to_empty(recurse=False)``).

        Returns:
            ``self``.
        """
        if recurse:
            for module in self.children():
                module._apply(fn)

        for name in ("weight", "weight_scale", "bias"):
            tensor = getattr(self, name)
            if tensor is None:
                continue
            # Call fn exactly once: it may perform a full device transfer.
            converted = fn(tensor)
            if converted.device != tensor.device:
                setattr(self, name, converted)
        return self