mm_weight.py 21 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
import torch
from abc import ABCMeta, abstractmethod
from vllm import _custom_ops as ops
4
import sgl_kernel
helloyongyang's avatar
helloyongyang committed
5
6
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
7
from lightx2v.utils.envs import *
root's avatar
root committed
8
from loguru import logger
Dongz's avatar
Dongz committed
9

10
11
12
13
try:
    import q8_kernels.functional as Q8F
except ImportError:
    Q8F = None
helloyongyang's avatar
helloyongyang committed
14

15
16
17
18
19
try:
    import deep_gemm
except ImportError:
    deep_gemm = None

helloyongyang's avatar
helloyongyang committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34

class MMWeightTemplate(metaclass=ABCMeta):
    """Abstract base class for matrix-multiply weight wrappers.

    Subclasses implement ``load`` (populate tensors from a weight dict) and
    ``apply`` (run the matmul). The device helpers below move the ``weight``
    attribute, the ``weight_scale`` attribute when one exists, and the ``bias``
    attribute when it is not ``None``; they therefore expect ``load`` to have
    set ``weight`` and ``bias`` beforehand.
    """

    def __init__(self, weight_name, bias_name):
        """Remember the dictionary keys of the weight and (optional) bias.

        Args:
            weight_name: Key of the weight tensor in the weight dict.
            bias_name: Key of the bias tensor, or ``None`` when there is no bias.
        """
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.config = {}

    @abstractmethod
    def load(self, weight_dict):
        """Populate this object's tensors from ``weight_dict``."""

    @abstractmethod
    def apply(self, input_tensor):
        """Apply the wrapped weight to ``input_tensor`` and return the result."""

    def set_config(self, config=None):
        """Attach a configuration mapping to this weight.

        Args:
            config: Mapping to store. When omitted, a fresh empty dict is used
                so that instances never share one default dictionary.
        """
        self.config = {} if config is None else config

    def to_cpu(self, non_blocking=False):
        """Move ``weight``, ``weight_scale`` (if present) and ``bias`` (if any) to CPU."""
        self.weight = self.weight.to("cpu", non_blocking=non_blocking)
        # Only quantized subclasses carry a scale tensor.
        if hasattr(self, "weight_scale"):
            self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
        if self.bias is not None:
            self.bias = self.bias.to("cpu", non_blocking=non_blocking)

    def to_cuda(self, non_blocking=False):
        """Move ``weight``, ``weight_scale`` (if present) and ``bias`` (if any) to CUDA."""
        self.weight = self.weight.cuda(non_blocking=non_blocking)
        # Only quantized subclasses carry a scale tensor.
        if hasattr(self, "weight_scale"):
            self.weight_scale = self.weight_scale.cuda(non_blocking=non_blocking)
        if self.bias is not None:
            self.bias = self.bias.cuda(non_blocking=non_blocking)

helloyongyang's avatar
helloyongyang committed
52

Dongz's avatar
Dongz committed
53
@MM_WEIGHT_REGISTER("Default")
class MMWeight(MMWeightTemplate):
    """Unquantized matmul weight, registered under the name ``"Default"``."""

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
        """Fetch the weight (stored transposed) and the optional bias from ``weight_dict``."""
        # Keep the transposed view so apply() can multiply input @ weight directly.
        self.weight = weight_dict[self.weight_name].t()
        if self.bias_name is None:
            self.bias = None
        else:
            self.bias = weight_dict[self.bias_name]

    def apply(self, input_tensor):
        """Return ``input_tensor @ weight`` (plus the bias when one was loaded)."""
        out_rows = input_tensor.shape[0]
        out_cols = self.weight.shape[1]
        # The result buffer inherits dtype and device from the input.
        result = torch.empty(
            (out_rows, out_cols),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )
        if self.bias is not None:
            return torch.addmm(self.bias, input_tensor, self.weight, out=result)
        return torch.mm(input_tensor, self.weight, out=result)

    def state_dict(self, destination=None):
        """Write CPU copies of the tensors into ``destination`` and return it.

        The weight is transposed back, undoing the transpose applied in ``load``.
        A new dict is created when ``destination`` is ``None``.
        """
        destination = {} if destination is None else destination
        weight_copy = self.weight.cpu().detach().clone()
        destination[self.weight_name] = weight_copy.t().contiguous()
        if self.bias is not None:
            destination[self.bias_name] = self.bias.cpu().detach().clone()
        return destination

helloyongyang's avatar
helloyongyang committed
79

Dongz's avatar
Dongz committed
80
@MM_WEIGHT_REGISTER("Default-Force-FP32")
class MMWeightForceFP32(MMWeight):
    """Variant of ``MMWeight`` whose weight and bias are cast to float32 on load."""

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
        """Load via the parent class, then upcast weight and bias to float32."""
        super().load(weight_dict)
        self.weight = self.weight.float()
        if self.bias is None:
            return
        self.bias = self.bias.float()


92
93
94
95
96
97
98
class MMWeightQuantTemplate(MMWeightTemplate):
    """Shared loading and activation-quantization helpers for quantized matmul weights.

    Concrete subclasses pick their behaviour in ``__init__`` by assigning:

    * ``load_func``: one of the ``load_*`` methods below, called by ``load``;
    * ``weight_need_transpose``: whether ``load`` stores ``weight.t()``;
    * ``act_quant_func``: one of the ``act_quant_*`` methods below.
    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)
        self.load_func = None
        self.weight_need_transpose = True
        self.act_quant_func = None

    # =========================
    # weight load functions
    # =========================

    def load(self, weight_dict):
        """Run the configured ``load_func`` and transpose the weight if requested."""
        self.load_func(weight_dict)
        if self.weight_need_transpose:
            self.weight = self.weight.t()

    def _quantize_on_load(self):
        """Return a truthy value when weights must be quantized here instead of read pre-quantized."""
        return GET_RUNNING_FLAG() == "save_naive_quant" or self.config.get("weight_auto_quant", False)

    def load_quantized(self, weight_dict):
        """Read an already-quantized weight and its ``<prefix>.weight_scale`` (cast to float)."""
        self.weight = weight_dict[self.weight_name]
        self.weight_scale = weight_dict[self.weight_name.removesuffix(".weight") + ".weight_scale"].float()

    def load_fp8_perchannel_sym(self, weight_dict):
        """Load an fp8 (e4m3) per-channel weight, quantizing it here when requested."""
        if self._quantize_on_load():
            self.weight = weight_dict[self.weight_name].to(torch.float32)
            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.float8_e4m3fn)
            self.weight_scale = self.weight_scale.to(torch.float32)
        else:
            self.load_quantized(weight_dict)
        self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None

    def load_int8_perchannel_sym(self, weight_dict):
        """Load an int8 per-channel weight, quantizing it here when requested."""
        if self._quantize_on_load():
            self.weight = weight_dict[self.weight_name].to(torch.float32)
            w_quantizer = IntegerQuantizer(8, True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.int8)
            self.weight_scale = self.weight_scale.to(torch.float32)
        else:
            self.load_quantized(weight_dict)
        self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None

    def load_fp8_perblock128_sym(self, weight_dict):
        """Load an fp8 weight with 128x128 block scales, quantizing it here when requested."""
        if self._quantize_on_load():
            self.weight = weight_dict[self.weight_name]
            self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight)
        else:
            self.load_quantized(weight_dict)
        self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None

    def per_block_cast_to_fp8(self, x):
        """Quantize a 2-D tensor to fp8 (e4m3) with one scale per 128x128 block.

        Args:
            x: 2-D tensor to quantize.

        Returns:
            Tuple ``(x_fp8, scales)`` where ``x_fp8`` has the shape of ``x`` and
            ``scales`` has shape ``(ceil(m / 128), ceil(n / 128))``.
        """
        assert x.dim() == 2
        m, n = x.shape
        block = 128
        # Round both dimensions up to a multiple of the block size with plain
        # integer math, so this step does not require the optional deep_gemm
        # package to be importable.
        padded_m = (m + block - 1) // block * block
        padded_n = (n + block - 1) // block * block
        x_padded = torch.zeros((padded_m, padded_n), dtype=x.dtype, device=x.device)
        x_padded[:m, :n] = x
        x_view = x_padded.view(-1, block, x_padded.size(1) // block, block)
        # Per-block absolute maximum, clamped from below so the division is safe.
        x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
        # 448.0 is used as the fp8 full-scale value (same constant as fp8_max below).
        x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))

    # =========================
    # act quant kernels
    # =========================

    def act_quant_fp8_perchannel_sym_vllm(self, x):
        """Dynamically quantize ``x`` to fp8 with vllm's ``scaled_fp8_quant`` (per-token mode)."""
        input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannel_sym_sgl(self, x):
        """Dynamically quantize ``x`` to fp8 with sgl_kernel's per-token kernel."""
        m, k = x.shape
        # Allocate on the input's own device rather than the current CUDA device,
        # so the buffers always live next to ``x``.
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device=x.device, requires_grad=False)
        input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device=x.device, requires_grad=False)
        sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale)
        return input_tensor_quant, input_tensor_scale

    def act_quant_int8_perchannel_sym_vllm(self, x):
        """Dynamically quantize ``x`` to int8 (symmetric) with vllm's ``scaled_int8_quant``."""
        input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannelgroup128_sym_deepgemm(self, x):
        """Quantize ``x`` to fp8 with one scale per row and group of 128 columns (pure torch).

        Requires a 2-D input whose second dimension is a multiple of 128.
        Returns ``(x_fp8, scales)`` with ``scales`` of shape ``(m, n // 128)``.
        """
        assert x.dim() == 2 and x.size(1) % 128 == 0
        m, n = x.shape
        x_view = x.view(m, -1, 128)
        x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
        return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)

    def act_quant_fp8_perchannelgroup128_sym_sgl(self, x):
        """Quantize ``x`` to fp8 with sgl_kernel's per-token, group-of-128 kernel."""
        m, k = x.shape
        # Allocate on the input's own device rather than the current CUDA device.
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device=x.device, requires_grad=False)
        input_tensor_scale = torch.empty((m, k // 128), dtype=torch.float32, device=x.device, requires_grad=False)
        sgl_kernel.sgl_per_token_group_quant_fp8(x, input_tensor_quant, input_tensor_scale, group_size=128, eps=1e-10, fp8_min=-448.0, fp8_max=448.0)
        return input_tensor_quant, input_tensor_scale

    def state_dict(self, destination=None):
        """Write CPU copies of weight, bias and weight scale into ``destination``.

        The weight is transposed back when ``load`` transposed it. A new dict
        is created when ``destination`` is ``None``. Returns ``destination``.
        """
        if destination is None:
            destination = {}
        if self.weight_need_transpose:
            destination[self.weight_name] = self.weight.cpu().detach().clone().t().contiguous()
        else:
            destination[self.weight_name] = self.weight.cpu().detach().clone().contiguous()
        if self.bias is not None:
            destination[self.bias_name] = self.bias.cpu().detach().clone()
        if hasattr(self, "weight_scale"):
            destination[self.weight_name.removesuffix(".weight") + ".weight_scale"] = self.weight_scale.cpu().detach().clone()
        return destination

198

Dongz's avatar
Dongz committed
199
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm")
class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: vllm
    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
        self.weight_need_transpose = True
        self.load_func = self.load_fp8_perchannel_sym

    def apply(self, input_tensor):
        """Quantize the activation, run the cutlass scaled matmul, and return the output."""
        n_rows = input_tensor.shape[0]
        n_cols = self.weight.shape[1]
        # The kernel writes in place into a buffer matching the input dtype/device.
        result = torch.empty(
            (n_rows, n_cols),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )
        act_quant, act_scale = self.act_quant_func(input_tensor)
        torch.ops._C.cutlass_scaled_mm(result, act_quant, self.weight, act_scale, self.weight_scale, self.bias)
        return result


Dongz's avatar
Dongz committed
227
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm")
class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: vllm
    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
        self.weight_need_transpose = True
        self.load_func = self.load_int8_perchannel_sym

    def apply(self, input_tensor):
        """Quantize the activation, run the cutlass scaled matmul, and return the output."""
        n_rows = input_tensor.shape[0]
        n_cols = self.weight.shape[1]
        # The kernel writes in place into a buffer matching the input dtype/device.
        result = torch.empty(
            (n_rows, n_cols),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )
        act_quant, act_scale = self.act_quant_func(input_tensor)
        torch.ops._C.cutlass_scaled_mm(result, act_quant, self.weight, act_scale, self.weight_scale, self.bias)
        return result


255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        """Quantize the activation and run the Q8F fp8 linear kernel.

        Returns the kernel output (requested as bfloat16) after ``squeeze(0)``.
        """
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        # The bias is optional (bias_name may be None); only cast it when present
        # instead of crashing on ``None.float()``.
        # NOTE(review): assumes the Q8F kernel accepts ``None`` as bias - confirm against q8_kernels.
        bias = self.bias.float() if self.bias is not None else None
        output_tensor = Q8F.linear.fp8_linear(input_tensor_quant, self.weight, bias, input_tensor_scale, self.weight_scale, out_dtype=torch.bfloat16)
        return output_tensor.squeeze(0)


Dongz's avatar
Dongz committed
278
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F")
class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        """Quantize the activation and run the Q8F int8 linear kernel.

        Returns the kernel output (requested as bfloat16) after ``squeeze(0)``.
        """
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        # The bias is optional (bias_name may be None); only cast it when present
        # instead of crashing on ``None.float()``.
        # NOTE(review): assumes the Q8F kernel accepts ``None`` as bias - confirm against q8_kernels.
        bias = self.bias.float() if self.bias is not None else None
        output_tensor = Q8F.linear.q8_linear(input_tensor_quant, self.weight, bias, input_tensor_scale, self.weight_scale, fuse_gelu=False, out_dtype=torch.bfloat16)
        return output_tensor.squeeze(0)


301
302
@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemm(MMWeightQuantTemplate):
    """
    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm

    Quant MM:
        Weight: fp8 perblock 128x128 sym
        Act: fp8 perchannel-pergroup group=128 dynamic sym
        Kernel: Deepgemm

    Reference: https://github.com/deepseek-ai/DeepGEMM

    Example:
        Act(1024, 2048) x Weight(2048, 4096) = Out(1024, 4096)

        Act : torch.Size([1024, 2048]), torch.float8_e4m3fn
        Act Scale: torch.Size([1024, 16]), torch.float32
        Weight : torch.Size([4096, 2048]), torch.float8_e4m3fn
        Weight Scale: torch.Size([32, 16]), torch.float32
        Out : torch.Size([1024, 4096]), torch.bfloat16
    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)
        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_deepgemm
        self.weight_need_transpose = False
        self.load_func = self.load_fp8_perblock128_sym

    def apply(self, input_tensor):
        """Quantize the activation, run the DeepGEMM fp8 matmul, then add the bias if any."""
        # The weight is not transposed at load time, so dim 0 is the output width.
        result = torch.empty(
            (input_tensor.shape[0], self.weight.shape[0]),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )
        act_quant, act_scale = self.act_quant_func(input_tensor)
        lhs = (act_quant, act_scale)
        rhs = (self.weight, self.weight_scale)
        deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, result)
        if self.bias is None:
            return result
        result.add_(self.bias)
        return result


@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl

    Quant MM:
        Weight: fp8 perblock 128x128 sym
        Act: fp8 pertoken-pergroup group=128 dynamic sym
        Kernel: quant-mm using Deepgemm, act dynamic quant using Sgl-kernel
    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)
        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_sgl
        self.weight_need_transpose = False
        self.load_func = self.load_fp8_perblock128_sym

    def apply(self, input_tensor):
        """Quantize the activation (sgl kernel), run the DeepGEMM fp8 matmul, add the bias if any."""
        # The weight is not transposed at load time, so dim 0 is the output width.
        result = torch.empty(
            (input_tensor.shape[0], self.weight.shape[0]),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )
        act_quant, act_scale = self.act_quant_func(input_tensor)
        lhs = (act_quant, act_scale)
        rhs = (self.weight, self.weight_scale)
        deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, result)
        if self.bias is None:
            return result
        result.add_(self.bias)
        return result


@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl")
class MMWeightWfp8channelAfp8channeldynamicVllmActSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: quant-mm using vllm, act dynamic quant using Sgl-kernel
    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl
        self.weight_need_transpose = True
        self.load_func = self.load_fp8_perchannel_sym

    def apply(self, input_tensor):
        """Quantize the activation (sgl kernel), run the cutlass scaled matmul, return the output."""
        n_rows = input_tensor.shape[0]
        n_cols = self.weight.shape[1]
        # The kernel writes in place into a buffer matching the input dtype/device.
        result = torch.empty(
            (n_rows, n_cols),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )
        act_quant, act_scale = self.act_quant_func(input_tensor)
        torch.ops._C.cutlass_scaled_mm(result, act_quant, self.weight, act_scale, self.weight_scale, self.bias)
        return result


helloyongyang's avatar
helloyongyang committed
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm")
class MMWeightWfp8channelAfp8channeldynamicSglActVllm(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
        self.weight_need_transpose = True
        self.load_func = self.load_fp8_perchannel_sym

    def apply(self, input_tensor):
        """Quantize the activation (vllm) and return the sgl fp8 scaled matmul result (bfloat16 requested)."""
        act_quant, act_scale = self.act_quant_func(input_tensor)
        return sgl_kernel.fp8_scaled_mm(
            act_quant,
            self.weight,
            act_scale,
            self.weight_scale,
            torch.bfloat16,
            bias=self.bias,
        )


423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl")
class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Sgl-kernel
    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl
        self.weight_need_transpose = True
        self.load_func = self.load_fp8_perchannel_sym

    def apply(self, input_tensor):
        """Quantize the activation (sgl kernel) and return the sgl fp8 scaled matmul result (bfloat16 requested)."""
        act_quant, act_scale = self.act_quant_func(input_tensor)
        return sgl_kernel.fp8_scaled_mm(
            act_quant,
            self.weight,
            act_scale,
            self.weight_scale,
            torch.bfloat16,
            bias=self.bias,
        )


@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm")
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        """Quantize the activation (vllm) and return the sgl int8 scaled matmul result.

        The kernel returns its own output tensor (bfloat16 requested), so no
        output buffer is pre-allocated here.
        """
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = sgl_kernel.int8_scaled_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, torch.bfloat16, self.bias)
        return output_tensor
472
473


Dongz's avatar
Dongz committed
474
if __name__ == "__main__":
    # Manual smoke test: build a few registered quantized weights from random
    # tensors, run one matmul each on CUDA, and log the output shape.

    def _run_smoke_case(register_key, weight_dict, weight_auto_quant):
        """Load ``weight_dict`` into the registered class and run one matmul on CUDA."""
        mm_weight = MM_WEIGHT_REGISTER[register_key]("xx.weight", "xx.bias")
        mm_weight.set_config({"weight_auto_quant": weight_auto_quant})
        mm_weight.load(weight_dict)
        # The random tensors above are created on CPU and load() does not move
        # them, while the activation below lives on CUDA; bring the weights over
        # before calling the CUDA kernels.
        mm_weight.to_cuda()
        input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
        output_tensor = mm_weight.apply(input_tensor)
        logger.info(output_tensor.shape)

    # Pre-quantized fp8 weight with an explicit scale tensor.
    _run_smoke_case(
        "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm",
        {
            "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
            "xx.bias": torch.randn(8192).to(torch.bfloat16),
            "xx.weight_scale": torch.randn(8192, 1).to(torch.float32),
        },
        False,
    )

    # Full-precision weight quantized to fp8 at load time.
    _run_smoke_case(
        "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm",
        {
            "xx.weight": torch.randn(8192, 4096),
            "xx.bias": torch.randn(8192).to(torch.bfloat16),
        },
        True,
    )

    # Full-precision weight quantized to int8 at load time.
    _run_smoke_case(
        "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
        {
            "xx.weight": torch.randn(8192, 4096),
            "xx.bias": torch.randn(8192).to(torch.bfloat16),
        },
        True,
    )