from abc import ABCMeta, abstractmethod

import torch

from lightx2v.utils.envs import *
from lightx2v.utils.global_paras import CALIB
from lightx2v.utils.quant_utils import FloatQuantizer, IntegerQuantizer
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER

try:
    from lightx2v_kernel.gemm import (
        cutlass_scaled_mxfp4_mm,
        cutlass_scaled_mxfp6_mxfp8_mm,
        cutlass_scaled_mxfp8_mm,
        cutlass_scaled_nvfp4_mm,
        scaled_mxfp4_quant,
        scaled_mxfp6_quant,
        scaled_mxfp8_quant,
        scaled_nvfp4_quant,
    )
except ImportError:
    scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm = None, None
    scaled_mxfp4_quant, cutlass_scaled_mxfp4_mm = None, None
    scaled_mxfp6_quant, cutlass_scaled_mxfp6_mxfp8_mm = None, None
    scaled_mxfp8_quant, cutlass_scaled_mxfp8_mm = None, None

try:
    from vllm import _custom_ops as ops
except ImportError:
    ops = None

try:
    import sgl_kernel
except ImportError:
    sgl_kernel = None

try:
    from q8_kernels.functional.linear import q8_linear
except ImportError:
    q8_linear = None

try:
    from q8_kernels.functional.linear import fp8_linear
except ImportError:
    fp8_linear = None

try:
    import deep_gemm
except ImportError:
    deep_gemm = None

try:
    from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
except ModuleNotFoundError:
    quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None

try:
    import gguf
except ImportError:
    gguf = None

try:
    import marlin_cuda_quant
except ModuleNotFoundError:
    marlin_cuda_quant = None


class MMWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.config = {}

    @abstractmethod
    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config={}):
        self.config = config

    def to_cuda(self, non_blocking=False):
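        # Upload the pinned host copies to the GPU; with non_blocking=True the copy can
        # overlap with compute because the source tensors live in page-locked memory.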
        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
        if hasattr(self, "pin_weight_scale"):
            self.weight_scale = self.pin_weight_scale.cuda(non_blocking=non_blocking)
        if hasattr(self, "pin_bias") and self.pin_bias is not None:
            self.bias = self.pin_bias.cuda(non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
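        # Offload back to CPU: when pinned staging buffers exist, copy the GPU tensors into
        # them (keeping the page-locked buffers for the next upload); otherwise fall back to .to("cpu").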
        if hasattr(self, "pin_weight"):
            self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
            if hasattr(self, "weight_scale_name"):
                self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
            if self.bias is not None:
                self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
        else:
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "weight_scale"):
                self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "bias") and self.bias is not None:
                self.bias = self.bias.to("cpu", non_blocking=non_blocking)


@MM_WEIGHT_REGISTER("Default")
class MMWeight(MMWeightTemplate):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

    def load(self, weight_dict):
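        # Weights already on GPU are kept as a transposed view for torch.mm/addmm below;
        # CPU weights are staged into pinned buffers so to_cuda() can upload them asynchronously.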
        device = weight_dict[self.weight_name].device
        if device.type == "cuda":
            self.weight = weight_dict[self.weight_name].t()
            if self.bias_name is not None:
                self.bias = weight_dict[self.bias_name]
            else:
                self.bias = None
        elif device.type == "cpu":
            weight_shape = weight_dict[self.weight_name].t().shape
            weight_dtype = weight_dict[self.weight_name].dtype
            self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
            self.pin_weight.copy_(weight_dict[self.weight_name].t())

            if self.bias_name is not None:
                bias_shape = weight_dict[self.bias_name].shape
                bias_dtype = weight_dict[self.bias_name].dtype
                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                self.pin_bias.copy_(weight_dict[self.bias_name])
            else:
                self.bias = None
                self.pin_bias = None
            del weight_dict[self.weight_name]
        else:
            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")

    def _calculate_size(self):
        if self.bias is not None:
            return self.weight.numel() * self.weight.element_size() + self.bias.numel() * self.bias.element_size()

        return self.weight.numel() * self.weight.element_size()

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        if self.bias is None:
            return torch.mm(input_tensor, self.weight, out=output_tensor)
        return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        destination[self.weight_name] = self.weight.cpu().detach().clone().t().contiguous()
        if hasattr(self, "bias") and self.bias is not None:
            destination[self.bias_name] = self.bias.cpu().detach().clone()
        return destination
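
# Example usage (illustrative sketch only; the tensor names and the dict-style registry
# lookup are assumptions, not taken from this repo's docs):
#   mm = MM_WEIGHT_REGISTER["Default"]("blocks.0.attn.qkv.weight", "blocks.0.attn.qkv.bias")
#   mm.load(weight_dict)  # weight_dict: tensors loaded from a checkpoint
#   mm.to_cuda()          # only needed when the tensors were staged in pinned CPU memory
#   out = mm.apply(x)     # x: [num_tokens, in_features]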


@MM_WEIGHT_REGISTER("Default-Force-FP32")
class MMWeightForceFP32(MMWeight):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

    def load(self, weight_dict):
        super().load(weight_dict)
        self.weight = self.weight.to(torch.float32)
        if hasattr(self, "bias") and self.bias is not None:
            self.bias = self.bias.to(torch.float32)


class MMWeightQuantTemplate(MMWeightTemplate):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale"
        self.load_func = None
        self.weight_need_transpose = True
        self.act_quant_func = None
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.infer_dtype = GET_DTYPE()

    # =========================
    # weight load functions
    # =========================

    def load_from_disk(self):  # Need Rewrite
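        # Lazily read tensors from the safetensors file handle; pin_memory() is only applied
        # outside torch.compile tracing (presumably to avoid graph breaks during compilation).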
        if not torch._dynamo.is_compiling():
            self.weight = self.lazy_load_file.get_tensor(self.weight_name).pin_memory()
            self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float().pin_memory()
            if self.bias_name is not None:
                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype).pin_memory()
        else:
            self.weight = self.lazy_load_file.get_tensor(self.weight_name)
            self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float()
            if self.bias_name is not None:
                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)

        if self.weight_need_transpose:
            self.weight = self.weight.t()

    def load(self, weight_dict):
        if not self.lazy_load:
            self.load_func(weight_dict)
            if self.weight_need_transpose:
                if hasattr(self, "weight"):
                    self.weight = self.weight.t()
                elif hasattr(self, "pin_weight"):
                    self.pin_weight = self.pin_weight.t()

    def clear(self):
        attrs = ["weight", "weight_scale", "bias", "pin_weight", "pin_weight_scale", "pin_bias"]
        for attr in attrs:
            if hasattr(self, attr):
                delattr(self, attr)
                setattr(self, attr, None)

    def _calculate_size(self):
        if self.bias is not None:
            return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size() + self.bias.numel() * self.bias.element_size()
        return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size()

    def load_quantized(self, weight_dict):
        device = weight_dict[self.weight_name].device
        if device.type == "cuda":
            self.weight = weight_dict[self.weight_name]
            self.weight_scale = weight_dict[self.weight_scale_name].float()
        elif device.type == "cpu":
            weight_shape = weight_dict[self.weight_name].shape
            weight_dtype = weight_dict[self.weight_name].dtype
            self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
            self.pin_weight.copy_(weight_dict[self.weight_name])

            weight_scale_shape = weight_dict[self.weight_scale_name].shape
            weight_scale_dtype = torch.float
            self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
            self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
            del weight_dict[self.weight_name]
        else:
            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")

    def load_fp8_perchannel_sym(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
            self.weight = weight_dict[self.weight_name].to(torch.float32)
            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.float8_e4m3fn)
            self.weight_scale = self.weight_scale.to(torch.float32)
        else:
            self.load_quantized(weight_dict)

        if self.bias_name is not None:
            device = weight_dict[self.bias_name].device
            if device.type == "cuda":
                self.bias = weight_dict[self.bias_name]
            elif device.type == "cpu":
                bias_shape = weight_dict[self.bias_name].shape
                bias_dtype = weight_dict[self.bias_name].dtype
                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                self.pin_bias.copy_(weight_dict[self.bias_name])
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
        else:
            self.bias = None
            self.pin_bias = None

    def load_int8_perchannel_sym(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
            self.weight = weight_dict[self.weight_name].to(torch.float32)
            w_quantizer = IntegerQuantizer(8, True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.int8)
            self.weight_scale = self.weight_scale.to(torch.float32)
        else:
            self.load_quantized(weight_dict)

        if self.bias_name is not None:
            device = weight_dict[self.bias_name].device
            if device.type == "cuda":
                self.bias = weight_dict[self.bias_name]
            elif device.type == "cpu":
                bias_shape = weight_dict[self.bias_name].shape
                bias_dtype = weight_dict[self.bias_name].dtype
                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                self.pin_bias.copy_(weight_dict[self.bias_name])
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
        else:
            self.bias = None
            self.pin_bias = None

    def load_mxfp4(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
            device = weight_dict[self.weight_name].device
            self.weight = weight_dict[self.weight_name].cuda().to(torch.bfloat16)
            self.weight, self.weight_scale = scaled_mxfp4_quant(self.weight)
            self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
        else:
            device = weight_dict[self.weight_name].device
            if device.type == "cuda":
                self.weight = weight_dict[self.weight_name]
                self.weight_scale = weight_dict[self.weight_scale_name]
            elif device.type == "cpu":
                weight_shape = weight_dict[self.weight_name].shape
                weight_dtype = weight_dict[self.weight_name].dtype
                self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                self.pin_weight.copy_(weight_dict[self.weight_name])

                weight_scale_shape = weight_dict[self.weight_scale_name].shape
                weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
                self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
                self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                del weight_dict[self.weight_name]
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")

        if self.bias_name is not None:
            device = weight_dict[self.bias_name].device
            if device.type == "cuda":
                self.bias = weight_dict[self.bias_name]
            elif device.type == "cpu":
                bias_shape = weight_dict[self.bias_name].shape
                bias_dtype = weight_dict[self.bias_name].dtype
                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                self.pin_bias.copy_(weight_dict[self.bias_name])
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
        else:
            self.bias = None
            self.pin_bias = None

    def load_mxfp6(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
            device = weight_dict[self.weight_name].device
            self.weight = weight_dict[self.weight_name].cuda().to(torch.bfloat16)
            self.weight, self.weight_scale = scaled_mxfp6_quant(self.weight)
            self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
        else:
            device = weight_dict[self.weight_name].device
            if device.type == "cuda":
                self.weight = weight_dict[self.weight_name]
                self.weight_scale = weight_dict[self.weight_scale_name]
            elif device.type == "cpu":
                weight_shape = weight_dict[self.weight_name].shape
                weight_dtype = weight_dict[self.weight_name].dtype
                self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                self.pin_weight.copy_(weight_dict[self.weight_name])

                weight_scale_shape = weight_dict[self.weight_scale_name].shape
                weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
                self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
                self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                del weight_dict[self.weight_name]
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")

        if self.bias_name is not None:
            device = weight_dict[self.bias_name].device
            if device.type == "cuda":
                self.bias = weight_dict[self.bias_name]
            elif device.type == "cpu":
                bias_shape = weight_dict[self.bias_name].shape
                bias_dtype = weight_dict[self.bias_name].dtype
                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                self.pin_bias.copy_(weight_dict[self.bias_name])
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
        else:
            self.bias = None
            self.pin_bias = None

    def load_mxfp8(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
            device = weight_dict[self.weight_name].device
            self.weight = weight_dict[self.weight_name].cuda().to(torch.bfloat16)
            self.weight, self.weight_scale = scaled_mxfp8_quant(self.weight)
            self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
        else:
            device = weight_dict[self.weight_name].device
            if device.type == "cuda":
                self.weight = weight_dict[self.weight_name]
                self.weight_scale = weight_dict[self.weight_scale_name]
            elif device.type == "cpu":
                weight_shape = weight_dict[self.weight_name].shape
                weight_dtype = weight_dict[self.weight_name].dtype
                self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                self.pin_weight.copy_(weight_dict[self.weight_name])

                weight_scale_shape = weight_dict[self.weight_scale_name].shape
                weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
                self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
                self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                del weight_dict[self.weight_name]
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")

        if self.bias_name is not None:
            device = weight_dict[self.bias_name].device
            if device.type == "cuda":
                self.bias = weight_dict[self.bias_name]
            elif device.type == "cpu":
                bias_shape = weight_dict[self.bias_name].shape
                bias_dtype = weight_dict[self.bias_name].dtype
                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                self.pin_bias.copy_(weight_dict[self.bias_name])
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
        else:
            self.bias = None
            self.pin_bias = None

    def load_nvfp4(self, weight_dict):
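        # nvfp4 path: per-tensor global scales. The 2688.0 constant is assumed to be the
        # fp8-e4m3 max (448) times the fp4 max (6); alpha rescales the quantized GEMM output.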
        device = weight_dict[self.weight_name].device

        input_absmax = weight_dict[self.weight_name.replace(".weight", ".input_absmax")]
        input_global_scale = (2688.0 / input_absmax).to(torch.float32)
        weight_global_scale = weight_dict[f"{self.weight_name}_global_scale"]
        alpha = 1.0 / (input_global_scale * weight_global_scale)

        if device.type == "cuda":
            self.weight = weight_dict[self.weight_name]
            self.weight_scale = weight_dict[self.weight_scale_name]
            self.input_global_scale = input_global_scale
            self.alpha = alpha
        elif device.type == "cpu":
            weight_shape = weight_dict[self.weight_name].shape
            weight_dtype = weight_dict[self.weight_name].dtype
            self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
            self.pin_weight.copy_(weight_dict[self.weight_name])

            weight_scale_shape = weight_dict[self.weight_scale_name].shape
            weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
            self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
            self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])

            input_global_scale_shape = input_global_scale.shape
            input_global_scale_dtype = input_global_scale.dtype
            self.pin_input_global_scale = torch.empty(input_global_scale_shape, pin_memory=True, dtype=input_global_scale_dtype)
            self.pin_input_global_scale.copy_(input_global_scale)

            alpha_shape = alpha.shape
            alpha_dtype = alpha.dtype
            self.pin_alpha = torch.empty(alpha_shape, pin_memory=True, dtype=alpha_dtype)
            self.pin_alpha.copy_(alpha)

            del weight_dict[self.weight_name]
        else:
            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")

        if self.bias_name is not None:
            device = weight_dict[self.bias_name].device
            if device.type == "cuda":
                self.bias = weight_dict[self.bias_name]
            elif device.type == "cpu":
                bias_shape = weight_dict[self.bias_name].shape
                bias_dtype = weight_dict[self.bias_name].dtype
                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                self.pin_bias.copy_(weight_dict[self.bias_name])
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
        else:
            self.bias = None
            self.pin_bias = None

    def load_fp8_perblock128_sym(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
            self.weight = weight_dict[self.weight_name]
            self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight)
        else:
            self.load_quantized(weight_dict)

        if self.bias_name is not None:
            device = weight_dict[self.bias_name].device
            if device.type == "cuda":
                self.bias = weight_dict[self.bias_name]
            elif device.type == "cpu":
                bias_shape = weight_dict[self.bias_name].shape
                bias_dtype = weight_dict[self.bias_name].dtype
                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                self.pin_bias.copy_(weight_dict[self.bias_name])
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
        else:
            self.bias = None
            self.pin_bias = None

    def per_block_cast_to_fp8(self, x):
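        # Pad to multiples of 128 and quantize each 128x128 block to fp8 e4m3 with a
        # per-block scale; 448.0 is the largest representable e4m3 value.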
        assert x.dim() == 2
        m, n = x.shape
        x_padded = torch.zeros(
            (deep_gemm.ceil_div(m, 128) * 128, deep_gemm.ceil_div(n, 128) * 128),
            dtype=x.dtype,
            device=x.device,
        )
        x_padded[:m, :n] = x
        x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
        x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
        x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))

    # =========================
    # act quant kernels
    # =========================
    def act_quant_int8_perchannel_sym_torchao(self, x):
        input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannel_sym_vllm(self, x):
        input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannel_sym_sgl(self, x):
        m, k = x.shape
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
        input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device="cuda", requires_grad=False)
        sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale)
        return input_tensor_quant, input_tensor_scale

    def act_quant_int8_perchannel_sym_vllm(self, x):
        input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_nvfp4(self, x):
        input_tensor_quant, input_tensor_scale = scaled_nvfp4_quant(x, self.input_global_scale)
        return input_tensor_quant, input_tensor_scale

    def act_quant_mxfp4(self, x):
        input_tensor_quant, input_tensor_scale = scaled_mxfp4_quant(x)
        return input_tensor_quant, input_tensor_scale

    def act_quant_mxfp8(self, x):
        input_tensor_quant, input_tensor_scale = scaled_mxfp8_quant(x)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannelgroup128_sym_deepgemm(self, x):
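        # Per-token, per-128-column-group symmetric fp8 quantization: each group is scaled by
        # 448 / amax, and the inverse scales (amax / 448) are returned with the quantized activations.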
        assert x.dim() == 2 and x.size(1) % 128 == 0
        m, n = x.shape
        x_view = x.view(m, -1, 128)
        x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
        return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)

    def act_quant_fp8_perchannelgroup128_sym_sgl(self, x):
        m, k = x.shape
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
        input_tensor_scale = torch.empty((m, k // 128), dtype=torch.float32, device="cuda", requires_grad=False)
        sgl_kernel.sgl_per_token_group_quant_fp8(
            x,
            input_tensor_quant,
            input_tensor_scale,
            group_size=128,
            eps=1e-10,
            fp8_min=-448.0,
            fp8_max=448.0,
        )
        return input_tensor_quant, input_tensor_scale

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        if self.weight_need_transpose:
            destination[self.weight_name] = self.weight.cpu().detach().clone().t().contiguous()
        else:
            destination[self.weight_name] = self.weight.cpu().detach().clone().contiguous()
        if hasattr(self, "bias") and self.bias is not None:
            destination[self.bias_name] = self.bias.cpu().detach().clone()
        if hasattr(self, "weight_scale"):
            destination[self.weight_name.removesuffix(".weight") + ".weight_scale"] = self.weight_scale.cpu().detach().clone()
        return destination


@MM_WEIGHT_REGISTER("fp8-vllm")
class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.bias if self.bias is not None else None,
        )
        return output_tensor


@MM_WEIGHT_REGISTER("int8-vllm")
class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.bias if self.bias is not None else None,
        )
        return output_tensor


@MM_WEIGHT_REGISTER("mxfp4")
class MMWeightWmxfp4Amxfp4dynamic(MMWeightQuantTemplate):
    """
    Name: W-mxfp4-A-mxfp4-dynamic

    Quant MM:
        Weight: mxfp4
        Act: mxfp4
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_mxfp4
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_mxfp4
        self.set_alpha()

    def set_alpha(self):
        self.alpha = torch.tensor(1.0, dtype=torch.float32)

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        self.alpha = self.alpha.to(self.weight.device)
        output_tensor = cutlass_scaled_mxfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("mxfp6-mxfp8")
class MMWeightWmxfp6Amxfp8dynamic(MMWeightQuantTemplate):
    """
    Name: W-mxfp6-A-mxfp8-dynamic

    Quant MM:
        Weight: mxfp6
        Act: mxfp8
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_mxfp6
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_mxfp8
        self.set_alpha()

    def set_alpha(self):
        self.alpha = torch.tensor(1.0, dtype=torch.float32)

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        self.alpha = self.alpha.to(self.weight.device)
        output_tensor = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("mxfp8")
class MMWeightWmxfp8Amxfp8dynamic(MMWeightQuantTemplate):
    """
    Name: W-mxfp8-A-mxfp8-dynamic

    Quant MM:
        Weight: mxfp8
        Act: mxfp8
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_mxfp8
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_mxfp8
        self.set_alpha()

    def set_alpha(self):
        self.alpha = torch.tensor(1.0, dtype=torch.float32)

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        self.alpha = self.alpha.to(self.weight.device)
        output_tensor = cutlass_scaled_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("nvfp4")
class MMWeightWnvfp4Anvfp4dynamic(MMWeightQuantTemplate):
    """
    Name: W-nvfp4-A-nvfp4-dynamic

    Quant MM:
        Weight: nvfp4
        Act: nvfp4
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_nvfp4
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_nvfp4

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = cutlass_scaled_nvfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor

    def to_cuda(self, non_blocking=False):
        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
        if hasattr(self, "pin_weight_scale"):
            self.weight_scale = self.pin_weight_scale.cuda(non_blocking=non_blocking)
            self.input_global_scale = self.pin_input_global_scale.cuda(non_blocking=non_blocking)
            self.alpha = self.pin_alpha.cuda(non_blocking=non_blocking)
        if hasattr(self, "pin_bias") and self.pin_bias is not None:
            self.bias = self.pin_bias.cuda(non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
        if hasattr(self, "pin_weight"):
            self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
            if hasattr(self, "weight_scale_name"):
                self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
                self.input_global_scale = self.pin_input_global_scale.copy_(self.input_global_scale, non_blocking=non_blocking).cpu()
                self.alpha = self.pin_alpha.copy_(self.alpha, non_blocking=non_blocking).cpu()
            if self.bias is not None:
                self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
        else:
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "weight_scale"):
                self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
                self.input_global_scale = self.input_global_scale.to("cpu", non_blocking=non_blocking)
                self.alpha = self.alpha.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "bias") and self.bias is not None:
                self.bias = self.bias.to("cpu", non_blocking=non_blocking)


@MM_WEIGHT_REGISTER("Calib")
class MMCalibNvfp4(MMWeight):
    """
    Name: calib

    Calib:
        absmax: torch.max(torch.abs(input_tensor))
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.running_absmax = None
        self.count = 0
        self.decay = 0.9

    def apply(self, input_tensor):
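        # Calibration pass: keep an exponential moving average of the activation absmax
        # (updated every other call) in CALIB["absmax"], then run the normal GEMM.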
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype, device = input_tensor.dtype, input_tensor.device

        current_absmax = torch.max(torch.abs(input_tensor)).to("cpu")
        if self.count % 2 == 0:
            if self.running_absmax is None:
                self.running_absmax = current_absmax
            else:
                self.running_absmax = self.decay * self.running_absmax + (1 - self.decay) * current_absmax
            CALIB["absmax"][self.weight_name] = self.running_absmax
        self.count = self.count + 1

        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        if self.bias is None:
            return torch.mm(input_tensor, self.weight, out=output_tensor)
        return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)


@MM_WEIGHT_REGISTER("fp8-q8f")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = fp8_linear(
            input_tensor_quant,
            self.weight,
            self.bias.float() if self.bias is not None else None,
            input_tensor_scale,
            self.weight_scale,
            out_dtype=self.infer_dtype,
        )
        return output_tensor.squeeze(0)


@MM_WEIGHT_REGISTER("int8-q8f")
class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = q8_linear(
            input_tensor_quant,
            self.weight,
            self.bias.float() if self.bias is not None else None,
            input_tensor_scale,
            self.weight_scale,
            fuse_gelu=False,
            out_dtype=self.infer_dtype,
        )
        return output_tensor.squeeze(0)


@MM_WEIGHT_REGISTER("fp8-b128-deepgemm")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl

    Quant MM:
        Weight: fp8 perblock 128x128 sym
        Act: fp8 pertoken-pergroup group=128 dynamic sym
        Kernel: quant-mm using Deepgemm, act dynamic quant using Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perblock128_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_sgl

    def apply(self, input_tensor):
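        # deep_gemm.gemm_fp8_fp8_bf16_nt produces a bf16 result, so the activation (and output)
        # dtype is assumed to be bf16 here; the bias, if present, is added after the GEMM.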
        shape = (input_tensor.shape[0], self.weight.shape[0])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        deep_gemm.gemm_fp8_fp8_bf16_nt(
            (input_tensor_quant, input_tensor_scale),
            (self.weight, self.weight_scale),
            output_tensor,
        )
        if hasattr(self, "bias") and self.bias is not None:
            output_tensor.add_(self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("fp8-sgl")
class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = sgl_kernel.fp8_scaled_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.infer_dtype,
            bias=self.bias,
        )
        return output_tensor


@MM_WEIGHT_REGISTER("int8-sgl")
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = sgl_kernel.int8_scaled_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.infer_dtype,
            self.bias if self.bias is not None else None,
        )
        return output_tensor


@MM_WEIGHT_REGISTER("int8-torchao")
class MMWeightWint8channelAint8channeldynamicTorchao(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Torchao
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=self.infer_dtype)
        if self.bias is not None:
            output_tensor = output_tensor + self.bias

        return output_tensor


class MMWeightGGUFTemplate(MMWeightQuantTemplate):
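    # Note: this class attribute dereferences gguf at import time, so the optional gguf
    # package must actually be installed for the GGUF weight classes to be defined.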
    TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16)

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

    def dequantize_func(self):
        # TODO: implement dequantize_func
        pass


@MM_WEIGHT_REGISTER("W-gguf-Q4_K")
class MMWeightGGUFQ4K(MMWeightGGUFTemplate):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)


@MM_WEIGHT_REGISTER("int4-g128-marlin")
class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
    """
    Name: "W-int4-group128-sym-Marlin

    Quant int4 x FP16:
        Weight: int4 pergroup sym
        Kernel: Marlin
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_quantized

    def load(self, weight_dict):
        assert not self.lazy_load
        self.load_func(weight_dict)
        self.workspace = weight_dict[f"{self.weight_name}_workspace"]

        if self.bias_name is not None:
            bias_shape = weight_dict[self.bias_name].shape
            bias_dtype = weight_dict[self.bias_name].dtype
            self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
            self.bias.copy_(weight_dict[self.bias_name])
        else:
            self.bias = None

    def apply(self, input_tensor):
        output_tensor = torch.empty(input_tensor.shape[:-1] + (self.weight_scale.shape[1],), dtype=input_tensor.dtype, device=input_tensor.device)
        marlin_cuda_quant.mul(input_tensor, self.weight, output_tensor, self.weight_scale.half(), self.workspace, -1, -1, -1, -1)
        if hasattr(self, "bias") and self.bias is not None:
            output_tensor.add_(self.bias)
        return output_tensor