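"""Matmul weight wrappers for lightx2v: a default (unquantized) path plus fp8/int8
per-channel-weight, dynamically quantized-activation variants backed by the vLLM
cutlass kernels or the optional q8_kernels (Q8F) kernels."""
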
import torch
from abc import ABCMeta, abstractmethod
from vllm import _custom_ops as ops
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer

try:
    import q8_kernels.functional as Q8F
except ImportError:
    Q8F = None


class MMWeightTemplate(metaclass=ABCMeta):
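    """
    Base class for matmul weight wrappers.

    Subclasses implement load() to pull (and optionally quantize) a weight/bias
    pair out of a state dict, and apply() to run the corresponding GEMM on an
    input tensor. to_cpu()/to_cuda() support offloading the loaded parameters.
    """
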
    def __init__(self, weight_name, bias_name):
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.config = {}

    @abstractmethod
    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        if config is not None:
            self.config = config

    def to_cpu(self, non_blocking=False):
        self.weight = self.weight.to("cpu", non_blocking=non_blocking)
        if self.bias is not None:
            self.bias = self.bias.to("cpu", non_blocking=non_blocking)

    def to_cuda(self, non_blocking=False):
        self.weight = self.weight.cuda(non_blocking=non_blocking)
        if self.bias is not None:
            self.bias = self.bias.cuda(non_blocking=non_blocking)


@MM_WEIGHT_REGISTER("Default")
class MMWeight(MMWeightTemplate):
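    """
    Name: Default

    MM:
        Weight: original dtype, stored transposed
        Kernel: torch.mm / torch.addmm
    """
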
    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
        self.weight = weight_dict[self.weight_name].t().cuda()
        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        if self.bias is None:
            return torch.mm(input_tensor, self.weight, out=output_tensor)
        return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)


@MM_WEIGHT_REGISTER("Default-Force-FP32")
class MMWeightForceFP32(MMWeight):
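    """
    Name: Default-Force-FP32

    MM:
        Weight: cast to fp32 after loading (bias as well, when present)
        Kernel: torch.mm / torch.addmm
    """
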
    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
        super().load(weight_dict)
        self.weight = self.weight.to(torch.float32)
        if self.bias is not None:
            self.bias = self.bias.to(torch.float32)


@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm")
class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: fp8 per-channel sym
        Act: fp8 per-channel dynamic sym
        Kernel: vllm
    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
        if self.config.get("weight_auto_quant", True):
            self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.float8_e4m3fn).t().cuda()
            self.weight_scale = self.weight_scale.to(torch.float32).cuda()
        else:
            self.weight = weight_dict[self.weight_name].t().cuda()
            # str.rstrip() strips a character set rather than a suffix and can over-trim the
            # parameter name, so derive the scale key with removesuffix() (Python 3.9+).
            self.weight_scale = weight_dict[self.weight_name.removesuffix(".weight") + ".weight_scale"].cuda()
        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
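        # Dynamically quantize the activations to fp8 with per-token scales, then run
        # vLLM's cutlass scaled-mm kernel, which applies both scale factors and the bias.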
        qinput, x_scale = ops.scaled_fp8_quant(input_tensor, None, scale_ub=None, use_per_token_if_dynamic=True)
        torch.ops._C.cutlass_scaled_mm(output_tensor, qinput, self.weight, x_scale, self.weight_scale, self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm")
class MMWeightWint8channelAint8channeldynamicVllm(MMWeightTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: int8 per-channel sym
        Act: int8 per-channel dynamic sym
        Kernel: vllm
    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
        if self.config.get("weight_auto_quant", True):
            self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
            w_quantizer = IntegerQuantizer(8, True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.int8).t().cuda()
            self.weight_scale = self.weight_scale.to(torch.float32).cuda()
        else:
            self.weight = weight_dict[self.weight_name].t().cuda()
            self.weight_scale = weight_dict[self.weight_name.removesuffix(".weight") + ".weight_scale"].cuda()
        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
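        # Dynamically quantize the activations to int8 (symmetric, per token), then run
        # vLLM's cutlass scaled-mm kernel, which applies both scale factors and the bias.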
        qinput, x_scale, _ = ops.scaled_int8_quant(input_tensor, scale=None, azp=None, symmetric=True)
        torch.ops._C.cutlass_scaled_mm(output_tensor, qinput, self.weight, x_scale, self.weight_scale, self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F")
class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: int8 per-channel sym
        Act: int8 per-channel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
        if self.config.get("weight_auto_quant", True):
            self.weight = weight_dict[self.weight_name].cuda()
            w_quantizer = IntegerQuantizer(8, True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.int8)
            self.weight_scale = self.weight_scale.to(torch.float32)
        else:
            self.weight = weight_dict[self.weight_name].cuda()
            self.weight_scale = weight_dict[self.weight_name.removesuffix(".weight") + ".weight_scale"].cuda()
        self.bias = weight_dict[self.bias_name].float().cuda() if self.bias_name is not None else None

    def apply(self, input_tensor, act=None):
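        # Quantize the activations with vLLM's int8 helper, then dispatch to the q8_kernels int8 GEMM.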
        qinput, x_scale, _ = ops.scaled_int8_quant(input_tensor, scale=None, azp=None, symmetric=True)
        output_tensor = Q8F.linear.q8_linear(qinput, self.weight, self.bias, x_scale, self.weight_scale, fuse_gelu=False, out_dtype=torch.bfloat16)
        return output_tensor.squeeze(0)


@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: fp8 per-channel sym
        Act: fp8 per-channel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
        if self.config.get("weight_auto_quant", True):
            self.weight = weight_dict[self.weight_name].cuda()
            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.float8_e4m3fn)
            self.weight_scale = self.weight_scale.to(torch.float32)
        else:
            self.weight = weight_dict[self.weight_name].cuda()
            self.weight_scale = weight_dict[self.weight_name.removesuffix(".weight") + ".weight_scale"].cuda()
        self.bias = weight_dict[self.bias_name].float().cuda() if self.bias_name is not None else None

    def apply(self, input_tensor):
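        # Quantize the activations to fp8 per token, then dispatch to the q8_kernels fp8 GEMM.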
        qinput, x_scale = ops.scaled_fp8_quant(input_tensor, None, scale_ub=None, use_per_token_if_dynamic=True)
        output_tensor = Q8F.linear.fp8_linear(qinput, self.weight, self.bias, x_scale, self.weight_scale, out_dtype=torch.bfloat16)
        return output_tensor.squeeze(0)


if __name__ == "__main__":
    weight_dict = {
        "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
        "xx.bias": torch.randn(8192).to(torch.bfloat16),
        "xx.weight_scale": torch.randn(8192, 1).to(torch.float32),
    }

    mm_weight = MM_WEIGHT_REGISTER["W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
    mm_weight.set_config({"weight_auto_quant": False})
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    print(output_tensor.shape)
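
    # Illustrative sketch (not part of the original checks): exercise the unquantized
    # "Default" path with a bf16 weight; assumes a CUDA device is available.
    weight_dict = {
        "xx.weight": torch.randn(8192, 4096).to(torch.bfloat16),
        "xx.bias": torch.randn(8192).to(torch.bfloat16),
    }
    mm_weight = MM_WEIGHT_REGISTER["Default"]("xx.weight", "xx.bias")
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    print(output_tensor.shape)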

    weight_dict = {
        "xx.weight": torch.randn(8192, 4096),
        "xx.bias": torch.randn(8192).to(torch.bfloat16),
    }

    mm_weight = MM_WEIGHT_REGISTER["W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
    mm_weight.set_config({"weight_auto_quant": True})
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    print(output_tensor.shape)

    weight_dict = {
        "xx.weight": torch.randn(8192, 4096),
        "xx.bias": torch.randn(8192).to(torch.bfloat16),
    }

    mm_weight = MM_WEIGHT_REGISTER["W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
    mm_weight.set_config({"weight_auto_quant": True})
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    print(output_tensor.shape)
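
    # Illustrative sketch (not part of the original checks): exercise the Q8F int8 path
    # when the optional q8_kernels package is installed; compatible hardware for the
    # Q8F kernels is assumed.
    if Q8F is not None:
        weight_dict = {
            "xx.weight": torch.randn(8192, 4096),
            "xx.bias": torch.randn(8192).to(torch.bfloat16),
        }
        mm_weight = MM_WEIGHT_REGISTER["W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F"]("xx.weight", "xx.bias")
        mm_weight.set_config({"weight_auto_quant": True})
        mm_weight.load(weight_dict)
        input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
        output_tensor = mm_weight.apply(input_tensor)
        print(output_tensor.shape)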