mm_weight.py 30.3 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
from abc import ABCMeta, abstractmethod
PengGao's avatar
PengGao committed
2
3

import torch
root's avatar
root committed
4
from loguru import logger
Dongz's avatar
Dongz committed
5

PengGao's avatar
PengGao committed
6
7
8
9
from lightx2v.utils.envs import *
from lightx2v.utils.quant_utils import FloatQuantizer, IntegerQuantizer
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER

gushiqiao's avatar
gushiqiao committed
10
11
12
13
14
15
16
17
18
19
try:
    from vllm import _custom_ops as ops
except ImportError:
    ops = None

try:
    import sgl_kernel
except ImportError:
    sgl_kernel = None

20
21
22
23
try:
    import q8_kernels.functional as Q8F
except ImportError:
    Q8F = None
helloyongyang's avatar
helloyongyang committed
24

25
26
27
28
29
try:
    import deep_gemm
except ImportError:
    deep_gemm = None

gushiqiao's avatar
gushiqiao committed
30
try:
Wq-dd's avatar
Wq-dd committed
31
    from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
gushiqiao's avatar
gushiqiao committed
32
33
34
except ModuleNotFoundError:
    quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None

35
36
37
38
39
try:
    import gguf
except ImportError:
    gguf = None

40
41
42
43
try:
    import marlin_cuda_quant
except ModuleNotFoundError:
    marlin_cuda_quant = None
helloyongyang's avatar
helloyongyang committed
44
45

class MMWeightTemplate(metaclass=ABCMeta):
gushiqiao's avatar
fix.  
gushiqiao committed
46
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
helloyongyang's avatar
helloyongyang committed
47
48
        self.weight_name = weight_name
        self.bias_name = bias_name
gushiqiao's avatar
fix.  
gushiqiao committed
49
50
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
helloyongyang's avatar
helloyongyang committed
51
52
53
54
55
56
57
58
59
60
        self.config = {}

    @abstractmethod
    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

61
62
    def set_config(self, config={}):
        self.config = config
helloyongyang's avatar
helloyongyang committed
63

gushiqiao's avatar
gushiqiao committed
64
65
    def to_cuda(self, non_blocking=False):
        self.weight = self.weight.cuda(non_blocking=non_blocking)
66
67
        if hasattr(self, "weight_scale"):
            self.weight_scale = self.weight_scale.cuda(non_blocking=non_blocking)
68
        if hasattr(self, "bias") and self.bias is not None:
gushiqiao's avatar
gushiqiao committed
69
70
            self.bias = self.bias.cuda(non_blocking=non_blocking)

71
72
73
74
75
76
77
78
79
80
81
82
83
84
    def to_cpu(self, non_blocking=False):
        if hasattr(self, "pinned_weight"):
            self.weight = self.pinned_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
            if hasattr(self, "weight_scale_name"):
                self.weight_scale = self.pinned_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
            if self.bias is not None:
                self.bias = self.pinned_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
        else:
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "weight_scale"):
                self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "bias") and self.bias is not None:
                self.bias = self.bias.to("cpu", non_blocking=non_blocking)

helloyongyang's avatar
helloyongyang committed
85

Dongz's avatar
Dongz committed
86
@MM_WEIGHT_REGISTER("Default")
helloyongyang's avatar
helloyongyang committed
87
class MMWeight(MMWeightTemplate):
gushiqiao's avatar
fix.  
gushiqiao committed
88
89
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
helloyongyang's avatar
helloyongyang committed
90
91

    def load(self, weight_dict):
92
        self.weight = weight_dict[self.weight_name].t()
Xinchi Huang's avatar
Xinchi Huang committed
93
        self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
94
        self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
Xinchi Huang's avatar
Xinchi Huang committed
95
        self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype) if self.bias is not None else None
helloyongyang's avatar
helloyongyang committed
96
97
98
99
100
101
102
103
104
105

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        if self.bias is None:
            return torch.mm(input_tensor, self.weight, out=output_tensor)
        return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)

helloyongyang's avatar
helloyongyang committed
106
107
108
109
    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        destination[self.weight_name] = self.weight.cpu().detach().clone().t().contiguous()
110
        if hasattr(self, "bias") and self.bias is not None:
helloyongyang's avatar
helloyongyang committed
111
112
113
            destination[self.bias_name] = self.bias.cpu().detach().clone()
        return destination

helloyongyang's avatar
helloyongyang committed
114

Dongz's avatar
Dongz committed
115
@MM_WEIGHT_REGISTER("Default-Force-FP32")
116
class MMWeightForceFP32(MMWeight):
gushiqiao's avatar
fix.  
gushiqiao committed
117
118
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
helloyongyang's avatar
helloyongyang committed
119
120
121
122

    def load(self, weight_dict):
        super().load(weight_dict)
        self.weight = self.weight.to(torch.float32)
123
        if hasattr(self, "bias") and self.bias is not None:
helloyongyang's avatar
helloyongyang committed
124
125
126
            self.bias = self.bias.to(torch.float32)


127
class MMWeightQuantTemplate(MMWeightTemplate):
128
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
gushiqiao's avatar
fix.  
gushiqiao committed
129
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
130
        self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale"
131
132
133
        self.load_func = None
        self.weight_need_transpose = True
        self.act_quant_func = None
134
135
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
136
        self.infer_dtype = GET_DTYPE()
137

helloyongyang's avatar
helloyongyang committed
138
139
140
    # =========================
    # weight load functions
    # =========================
141

142
143
144
145
146
    def load_from_disk(self):
        if not torch._dynamo.is_compiling():
            self.weight = self.lazy_load_file.get_tensor(self.weight_name).pin_memory()
            self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float().pin_memory()
            if self.bias_name is not None:
147
                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype).pin_memory()
148
149
150
151
        else:
            self.weight = self.lazy_load_file.get_tensor(self.weight_name)
            self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float()
            if self.bias_name is not None:
152
                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
153

helloyongyang's avatar
helloyongyang committed
154
155
        if self.weight_need_transpose:
            self.weight = self.weight.t()
156

157
158
159
160
161
    def load(self, weight_dict):
        if not self.lazy_load:
            self.load_func(weight_dict)
            if self.weight_need_transpose:
                self.weight = self.weight.t()
gushiqiao's avatar
Fix  
gushiqiao committed
162
                self.pinned_weight = self.pinned_weight.t()
163
164

    def clear(self):
gushiqiao's avatar
FIX  
gushiqiao committed
165
        attrs = ["weight", "weight_scale", "bias", "pinned_weight", "pinned_weight_scale", "pinned_bias"]
166
167
168
169
170
171
172
173
174
175
176
        for attr in attrs:
            if hasattr(self, attr):
                delattr(self, attr)
                setattr(self, attr, None)

    def _calculate_size(self):
        if self.bias is not None:
            return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size() + self.bias.numel() * self.bias.element_size()

        return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size()

177
    def load_quantized(self, weight_dict):
178
        self.weight = weight_dict[self.weight_name]
179
180
181
182
        self.weight_scale = weight_dict[self.weight_scale_name].float()

        self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
        self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
183
184

    def load_fp8_perchannel_sym(self, weight_dict):
185
        if self.config.get("weight_auto_quant", False):
186
            self.weight = weight_dict[self.weight_name].to(torch.float32)
187
188
189
190
            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.float8_e4m3fn)
            self.weight_scale = self.weight_scale.to(torch.float32)
191
192
            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
            self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
193
194
        else:
            self.load_quantized(weight_dict)
195
196
197
198
199
200

        if self.bias_name is not None:
            self.bias = weight_dict[self.bias_name]
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
        else:
            self.bias = None
201
202

    def load_int8_perchannel_sym(self, weight_dict):
203
        if self.config.get("weight_auto_quant", False):
204
            self.weight = weight_dict[self.weight_name].to(torch.float32)
205
206
207
208
            w_quantizer = IntegerQuantizer(8, True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.int8)
            self.weight_scale = self.weight_scale.to(torch.float32)
209
210
            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
            self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
211
212
        else:
            self.load_quantized(weight_dict)
213
214
215
216
217
218

        if self.bias_name is not None:
            self.bias = weight_dict[self.bias_name]
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
        else:
            self.bias = None
219
220

    def load_fp8_perblock128_sym(self, weight_dict):
221
        if self.config.get("weight_auto_quant", False):
222
            self.weight = weight_dict[self.weight_name]
223
            self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight)
224
225
            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
            self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
226
227
        else:
            self.load_quantized(weight_dict)
228
229
230
231
232
233

        if self.bias_name is not None:
            self.bias = weight_dict[self.bias_name]
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
        else:
            self.bias = None
234
235
236
237

    def per_block_cast_to_fp8(self, x):
        assert x.dim() == 2
        m, n = x.shape
238
239
240
241
242
        x_padded = torch.zeros(
            (deep_gemm.ceil_div(m, 128) * 128, deep_gemm.ceil_div(n, 128) * 128),
            dtype=x.dtype,
            device=x.device,
        )
243
244
245
246
247
248
        x_padded[:m, :n] = x
        x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
        x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
        x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))

helloyongyang's avatar
helloyongyang committed
249
250
251
    # =========================
    # act quant kernels
    # =========================
gushiqiao's avatar
gushiqiao committed
252
253
254
    def act_quant_int8_perchannel_sym_torchao(self, x):
        input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
        return input_tensor_quant, input_tensor_scale
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281

    def act_quant_fp8_perchannel_sym_vllm(self, x):
        input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannel_sym_sgl(self, x):
        m, k = x.shape
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
        input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device="cuda", requires_grad=False)
        sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale)
        return input_tensor_quant, input_tensor_scale

    def act_quant_int8_perchannel_sym_vllm(self, x):
        input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannelgroup128_sym_deepgemm(self, x):
        assert x.dim() == 2 and x.size(1) % 128 == 0
        m, n = x.shape
        x_view = x.view(m, -1, 128)
        x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
        return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)

    def act_quant_fp8_perchannelgroup128_sym_sgl(self, x):
        m, k = x.shape
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
        input_tensor_scale = torch.empty((m, k // 128), dtype=torch.float32, device="cuda", requires_grad=False)
282
283
284
285
286
287
288
289
290
        sgl_kernel.sgl_per_token_group_quant_fp8(
            x,
            input_tensor_quant,
            input_tensor_scale,
            group_size=128,
            eps=1e-10,
            fp8_min=-448.0,
            fp8_max=448.0,
        )
291
292
        return input_tensor_quant, input_tensor_scale

293
294
295
    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
helloyongyang's avatar
helloyongyang committed
296
297
298
299
        if self.weight_need_transpose:
            destination[self.weight_name] = self.weight.cpu().detach().clone().t().contiguous()
        else:
            destination[self.weight_name] = self.weight.cpu().detach().clone().contiguous()
300
        if hasattr(self, "bias") and self.bias is not None:
301
302
            destination[self.bias_name] = self.bias.cpu().detach().clone()
        if hasattr(self, "weight_scale"):
303
            destination[self.weight_name.removesuffix(".weight") + ".weight_scale"] = self.weight_scale.cpu().detach().clone()
304
305
        return destination

306

Dongz's avatar
Dongz committed
307
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm")
308
class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
Dongz's avatar
Dongz committed
309
    """
helloyongyang's avatar
helloyongyang committed
310
311
312
313
314
315
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: vllm
Dongz's avatar
Dongz committed
316
317
    """

318
319
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
320
321
322
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
helloyongyang's avatar
helloyongyang committed
323
324
325
326
327
328

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
329
330

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
331
332
333
334
335
336
337
338
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.bias,
        )
helloyongyang's avatar
helloyongyang committed
339
340
341
        return output_tensor


Dongz's avatar
Dongz committed
342
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm")
343
class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
Dongz's avatar
Dongz committed
344
    """
helloyongyang's avatar
helloyongyang committed
345
346
347
348
349
350
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: vllm
Dongz's avatar
Dongz committed
351
352
    """

353
354
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
355
356
357
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
helloyongyang's avatar
helloyongyang committed
358
359
360
361
362
363

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
364
365

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
366
367
368
369
370
371
372
373
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.bias,
        )
helloyongyang's avatar
helloyongyang committed
374
375
376
        return output_tensor


377
378
379
380
381
382
383
384
385
386
387
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Q8F
    """

388
389
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
390
391
392
393
394
395
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
396
397
398
399
400
401
        output_tensor = Q8F.linear.fp8_linear(
            input_tensor_quant,
            self.weight,
            self.bias.float(),
            input_tensor_scale,
            self.weight_scale,
402
            out_dtype=self.infer_dtype,
403
        )
404
405
406
        return output_tensor.squeeze(0)


Dongz's avatar
Dongz committed
407
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F")
408
class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
Dongz's avatar
Dongz committed
409
    """
410
411
412
413
414
415
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Q8F
Dongz's avatar
Dongz committed
416
417
    """

418
419
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
420
421
422
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
423

424
425
    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
426
427
428
429
430
431
432
        output_tensor = Q8F.linear.q8_linear(
            input_tensor_quant,
            self.weight,
            self.bias.float(),
            input_tensor_scale,
            self.weight_scale,
            fuse_gelu=False,
433
            out_dtype=self.infer_dtype,
434
        )
435
436
437
        return output_tensor.squeeze(0)


438
439
@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemm(MMWeightQuantTemplate):
Dongz's avatar
Dongz committed
440
    """
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm

    Quant MM:
        Weight: fp8 perblock 128x128 sym
        Act: fp8 perchannel-pergroup group=128 dynamic sym
        Kernel: Deepgemm

    Reference: https://github.com/deepseek-ai/DeepGEMM

    Example:
        Act(1024, 2048) x Weight(2048, 4096) = Out(1024, 4096)

        Act : torch.Size([1024, 2048]), torch.float8_e4m3fn
        Act Scale: torch.Size([1024, 16]), torch.float32
        Weight : torch.Size([4096, 2048]), torch.float8_e4m3fn
        Weight Scale: torch.Size([32, 16]), torch.float32
457
        Out : torch.Size([1024, 4096]), self.infer_dtype
458
459
    """

460
461
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
462
463
464
465
466
467
468
469
470
471
472
        self.load_func = self.load_fp8_perblock128_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_deepgemm

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[0])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
473
474
475
476
477
478
        deep_gemm.gemm_fp8_fp8_bf16_nt(
            (input_tensor_quant, input_tensor_scale),
            (self.weight, self.weight_scale),
            output_tensor,
        )
        if hasattr(self, "bias") and self.bias is not None:
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
            output_tensor.add_(self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl

    Quant MM:
        Weight: fp8 perblock 128x128 sym
        Act: fp8 pertoken-pergroup group=128 dynamic sym
        Kernel: quant-mm using Deepgemm, act dynamic quant using Sgl-kernel
    """

494
495
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
496
497
498
499
500
501
502
503
504
505
506
        self.load_func = self.load_fp8_perblock128_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_sgl

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[0])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
507
508
509
510
511
512
        deep_gemm.gemm_fp8_fp8_bf16_nt(
            (input_tensor_quant, input_tensor_scale),
            (self.weight, self.weight_scale),
            output_tensor,
        )
        if hasattr(self, "bias") and self.bias is not None:
513
514
515
516
517
518
519
520
            output_tensor.add_(self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl")
class MMWeightWfp8channelAfp8channeldynamicVllmActSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl
521
522
523
524

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
525
        Kernel: quant-mm using vllm, act dynamic quant using Sgl-kernel
Dongz's avatar
Dongz committed
526
527
    """

528
529
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
530
531
532
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl
533

534
535
536
537
538
539
540
    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
541
542
543
544
545
546
547
548
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.bias,
        )
549
550
551
        return output_tensor


helloyongyang's avatar
helloyongyang committed
552
553
554
555
556
557
558
559
560
561
562
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm")
class MMWeightWfp8channelAfp8channeldynamicSglActVllm(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

563
564
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
helloyongyang's avatar
helloyongyang committed
565
566
567
568
569
570
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
571
572
573
574
575
        output_tensor = sgl_kernel.fp8_scaled_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
576
            self.infer_dtype,
577
578
            bias=self.bias,
        )
helloyongyang's avatar
helloyongyang committed
579
580
581
        return output_tensor


582
583
584
585
586
587
588
589
590
591
592
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl")
class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Sgl-kernel
    """

593
594
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
595
596
597
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl
598
599

    def apply(self, input_tensor):
600
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
601
602
603
604
605
        output_tensor = sgl_kernel.fp8_scaled_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
606
            self.infer_dtype,
607
608
            bias=self.bias,
        )
609
610
611
612
        return output_tensor


@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm")
helloyongyang's avatar
helloyongyang committed
613
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
614
615
616
617
618
619
620
621
622
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

623
624
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
625
626
627
628
629
630
631
632
633
634
635
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
636
637
638
639
640
        output_tensor = sgl_kernel.int8_scaled_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
641
            self.infer_dtype,
642
643
            self.bias,
        )
644
        return output_tensor
645
646


gushiqiao's avatar
gushiqiao committed
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao")
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Torchao
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao

    def apply(self, input_tensor):
        input_tensor = input_tensor
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
667
        output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=self.infer_dtype)
gushiqiao's avatar
gushiqiao committed
668
669
670
671
672
673
        if self.bias is not None:
            output_tensor = output_tensor + self.bias

        return output_tensor


674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
class MMWeightGGUFTemplate(MMWeightQuantTemplate):
    TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16)

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

    def dequantize_func(self):
        # TODO: implement dequantize_func
        pass


@MM_WEIGHT_REGISTER("W-gguf-Q4_K")
class MMWeightGGUFQ4K(MMWeightGGUFTemplate):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

690
691
692
693
694
695
696
697
698
@MM_WEIGHT_REGISTER("W-int4-group128-sym-Marlin")
class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
    """
    Name: "W-int4-group128-sym-Marlin

    Quant int4 x FP16:
        Weight: int4 pergroup sym
        Kernel: Marlin
    """
699

700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_quantized

    def load(self, weight_dict):
        assert not self.lazy_load
        self.load_func(weight_dict)
        self.workspace = weight_dict[f"{self.weight_name}_workspace"]
        if self.bias_name is not None:
            self.bias = weight_dict[self.bias_name]
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
        else:
            self.bias = None
    
    def apply(self, input_tensor):
        output_tensor = torch.empty(input_tensor.shape[:-1] + (self.weight_scale.shape[1],), dtype=input_tensor.dtype, device=input_tensor.device)
        marlin_cuda_quant.mul(input_tensor, self.weight, output_tensor, self.weight_scale.half(), self.workspace, -1, -1, -1, -1)
        if hasattr(self, "bias") and self.bias is not None:
            output_tensor.add_(self.bias)
        return output_tensor
    
Dongz's avatar
Dongz committed
721
if __name__ == "__main__":
helloyongyang's avatar
helloyongyang committed
722
    weight_dict = {
helloyongyang's avatar
helloyongyang committed
723
        "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
Dongz's avatar
Dongz committed
724
725
        "xx.bias": torch.randn(8192).to(torch.bfloat16),
        "xx.weight_scale": torch.randn(8192, 1).to(torch.float32),
helloyongyang's avatar
helloyongyang committed
726
727
    }

Dongz's avatar
Dongz committed
728
729
    mm_weight = MM_WEIGHT_REGISTER["W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
    mm_weight.set_config({"weight_auto_quant": False})
helloyongyang's avatar
helloyongyang committed
730
731
732
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
root's avatar
root committed
733
    logger.info(output_tensor.shape)
helloyongyang's avatar
helloyongyang committed
734
735

    weight_dict = {
Dongz's avatar
Dongz committed
736
737
        "xx.weight": torch.randn(8192, 4096),
        "xx.bias": torch.randn(8192).to(torch.bfloat16),
helloyongyang's avatar
helloyongyang committed
738
739
    }

Dongz's avatar
Dongz committed
740
741
    mm_weight = MM_WEIGHT_REGISTER["W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
    mm_weight.set_config({"weight_auto_quant": True})
helloyongyang's avatar
helloyongyang committed
742
743
744
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
root's avatar
root committed
745
    logger.info(output_tensor.shape)
helloyongyang's avatar
helloyongyang committed
746
747

    weight_dict = {
Dongz's avatar
Dongz committed
748
749
        "xx.weight": torch.randn(8192, 4096),
        "xx.bias": torch.randn(8192).to(torch.bfloat16),
helloyongyang's avatar
helloyongyang committed
750
751
    }

Dongz's avatar
Dongz committed
752
753
    mm_weight = MM_WEIGHT_REGISTER["W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
    mm_weight.set_config({"weight_auto_quant": True})
helloyongyang's avatar
helloyongyang committed
754
755
756
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
root's avatar
root committed
757
    logger.info(output_tensor.shape)