mm_weight.py 26.8 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
from abc import ABCMeta, abstractmethod
PengGao's avatar
PengGao committed
2
3

import torch
Dongz's avatar
Dongz committed
4

PengGao's avatar
PengGao committed
5
6
7
8
from lightx2v.utils.envs import *
from lightx2v.utils.quant_utils import FloatQuantizer, IntegerQuantizer
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER

gushiqiao's avatar
gushiqiao committed
9
10
11
12
13
14
15
16
17
18
try:
    from vllm import _custom_ops as ops
except ImportError:
    ops = None

try:
    import sgl_kernel
except ImportError:
    sgl_kernel = None

19
20
21
22
try:
    import q8_kernels.functional as Q8F
except ImportError:
    Q8F = None
helloyongyang's avatar
helloyongyang committed
23

24
25
26
27
28
try:
    import deep_gemm
except ImportError:
    deep_gemm = None

gushiqiao's avatar
gushiqiao committed
29
try:
Wq-dd's avatar
Wq-dd committed
30
    from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
gushiqiao's avatar
gushiqiao committed
31
32
33
except ModuleNotFoundError:
    quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None

34
35
36
37
38
try:
    import gguf
except ImportError:
    gguf = None

39
40
41
42
try:
    import marlin_cuda_quant
except ModuleNotFoundError:
    marlin_cuda_quant = None
helloyongyang's avatar
helloyongyang committed
43

44

helloyongyang's avatar
helloyongyang committed
45
class MMWeightTemplate(metaclass=ABCMeta):
    """Abstract base for matmul (linear-layer) weights.

    Concrete subclasses implement ``load`` (populate tensors from a state
    dict) and ``apply`` (run the matmul on an input tensor).  The template
    also provides CPU<->CUDA movement helpers built around pinned host
    buffers (``pin_weight`` / ``pin_weight_scale`` / ``pin_bias``) for
    overlap-friendly, optionally non-blocking transfers.
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        # Keys used to look the tensors up in the checkpoint state dict.
        self.weight_name = weight_name
        self.bias_name = bias_name
        # Lazy loading reads tensors from `lazy_load_file` on demand instead
        # of taking them from an in-memory weight dict.
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.config = {}

    @abstractmethod
    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        """Attach a config dict (e.g. quantization options).

        A fresh empty dict is created when no config is given, so instances
        never share a mutable default dict.
        """
        self.config = config if config is not None else {}

    def to_cuda(self, non_blocking=False):
        """Copy tensors from the pinned host buffers to the current CUDA device."""
        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
        if hasattr(self, "pin_weight_scale"):
            self.weight_scale = self.pin_weight_scale.cuda(non_blocking=non_blocking)
        if hasattr(self, "pin_bias") and self.pin_bias is not None:
            self.bias = self.pin_bias.cuda(non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
        """Move tensors back to host, reusing the pinned buffers when present."""
        if hasattr(self, "pin_weight"):
            # copy_ writes the (possibly GPU) tensor into the pinned buffer;
            # .cpu() then rebinds the attribute to the host tensor.
            self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
            if hasattr(self, "weight_scale_name"):
                self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
            if self.bias is not None:
                self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
        else:
            # No pinned buffers were ever allocated (weights loaded on CUDA).
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "weight_scale"):
                self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "bias") and self.bias is not None:
                self.bias = self.bias.to("cpu", non_blocking=non_blocking)
84

helloyongyang's avatar
helloyongyang committed
85

Dongz's avatar
Dongz committed
86
@MM_WEIGHT_REGISTER("Default")
class MMWeight(MMWeightTemplate):
    """Unquantized matmul weight.

    ``load`` stores the transposed weight (in-features x out-features) either
    directly on CUDA or into pinned host buffers for later ``to_cuda``.
    ``apply`` computes ``input @ weight (+ bias)``.
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        # Default so apply()/_calculate_size() are well-defined even when
        # bias_name is None (previously self.bias was left unset on the
        # CUDA load path, raising AttributeError later).
        self.bias = None

    def load(self, weight_dict):
        device = weight_dict[self.weight_name].device
        if device.type == "cuda":
            # Keep a transposed view; torch.mm/addmm handle the non-contiguity.
            self.weight = weight_dict[self.weight_name].t()
            if self.bias_name is not None:
                self.bias = weight_dict[self.bias_name]
        elif device.type == "cpu":
            weight_shape = weight_dict[self.weight_name].t().shape
            weight_dtype = weight_dict[self.weight_name].dtype
            # Stage in pinned memory so to_cuda() can copy asynchronously.
            self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
            self.pin_weight.copy_(weight_dict[self.weight_name].t())
            if self.bias_name is not None:
                bias_shape = weight_dict[self.bias_name].shape
                bias_dtype = weight_dict[self.bias_name].dtype
                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                self.pin_bias.copy_(weight_dict[self.bias_name])
            else:
                self.pin_bias = None
            # Release the original CPU copy; the pinned buffer now owns the data.
            del weight_dict[self.weight_name]
        else:
            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")

    def _calculate_size(self):
        """Return the byte size of weight (+ bias when present)."""
        if self.bias is not None:
            return self.weight.numel() * self.weight.element_size() + self.bias.numel() * self.bias.element_size()
        return self.weight.numel() * self.weight.element_size()

    def apply(self, input_tensor):
        """Compute ``input_tensor @ weight (+ bias)`` into a preallocated buffer."""
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        if self.bias is None:
            return torch.mm(input_tensor, self.weight, out=output_tensor)
        return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)

    def state_dict(self, destination=None):
        """Export tensors in checkpoint layout (weight transposed back)."""
        if destination is None:
            destination = {}
        destination[self.weight_name] = self.weight.cpu().detach().clone().t().contiguous()
        if hasattr(self, "bias") and self.bias is not None:
            destination[self.bias_name] = self.bias.cpu().detach().clone()
        return destination

helloyongyang's avatar
helloyongyang committed
137

Dongz's avatar
Dongz committed
138
@MM_WEIGHT_REGISTER("Default-Force-FP32")
class MMWeightForceFP32(MMWeight):
    """Unquantized matmul weight upcast to float32 after loading."""

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

    def load(self, weight_dict):
        super().load(weight_dict)
        # The base load stores into self.weight (CUDA source) or
        # self.pin_weight (CPU source); upcast whichever exists so the CPU
        # offload path does not raise AttributeError here.  NOTE(review):
        # .to(float32) on a pinned buffer returns a new, unpinned tensor.
        if hasattr(self, "pin_weight"):
            self.pin_weight = self.pin_weight.to(torch.float32)
        else:
            self.weight = self.weight.to(torch.float32)
        if hasattr(self, "bias") and self.bias is not None:
            self.bias = self.bias.to(torch.float32)
        if hasattr(self, "pin_bias") and self.pin_bias is not None:
            self.pin_bias = self.pin_bias.to(torch.float32)


150
class MMWeightQuantTemplate(MMWeightTemplate):
    """Base class for quantized matmul weights.

    Subclasses configure three hooks in ``__init__``:
      - ``load_func``: one of the ``load_*`` methods below that populates
        weight / weight_scale (and bias) from a state dict,
      - ``weight_need_transpose``: whether the weight is transposed after load
        (kernel-layout dependent),
      - ``act_quant_func``: one of the ``act_quant_*`` methods that dynamically
        quantizes activations at apply time.
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        # "foo.weight" -> "foo.weight_scale": checkpoint key of the scale tensor.
        self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale"
        # Hooks filled in by concrete subclasses.
        self.load_func = None
        self.weight_need_transpose = True
        self.act_quant_func = None
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        # Dtype used for kernel outputs / bias casts (from env config).
        self.infer_dtype = GET_DTYPE()

    # =========================
    # weight load functions
    # =========================

    def load_from_disk(self):  # Need Rewrite
        """Read weight/scale/bias straight from the lazy-load safetensors file.

        pin_memory() is skipped while torch.compile is tracing, since pinned
        allocation is a host-side side effect the compiler cannot trace.
        """
        if not torch._dynamo.is_compiling():
            self.weight = self.lazy_load_file.get_tensor(self.weight_name).pin_memory()
            self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float().pin_memory()
            if self.bias_name is not None:
                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype).pin_memory()
        else:
            self.weight = self.lazy_load_file.get_tensor(self.weight_name)
            self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float()
            if self.bias_name is not None:
                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)

        if self.weight_need_transpose:
            self.weight = self.weight.t()

    def load(self, weight_dict):
        """Eagerly load via the subclass-selected ``load_func`` (no-op when lazy)."""
        if not self.lazy_load:
            self.load_func(weight_dict)
            if self.weight_need_transpose:
                # The loader may have produced either a device tensor or a
                # pinned host tensor; transpose whichever exists.
                if hasattr(self, "weight"):
                    self.weight = self.weight.t()
                elif hasattr(self, "pin_weight"):
                    self.pin_weight = self.pin_weight.t()

    def clear(self):
        """Drop all tensor references (rebinding each attr to None)."""
        attrs = ["weight", "weight_scale", "bias", "pin_weight", "pin_weight_scale", "pin_bias"]
        for attr in attrs:
            if hasattr(self, attr):
                delattr(self, attr)
                setattr(self, attr, None)

    def _calculate_size(self):
        """Return the byte size of weight + scale (+ bias when present)."""
        if self.bias is not None:
            return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size() + self.bias.numel() * self.bias.element_size()
        return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size()

    def load_quantized(self, weight_dict):
        """Load an already-quantized weight and its scale from the state dict.

        CUDA tensors are referenced directly; CPU tensors are staged into
        pinned buffers for later async transfer.
        """
        device = weight_dict[self.weight_name].device
        if device.type == "cuda":
            self.weight = weight_dict[self.weight_name]
            self.weight_scale = weight_dict[self.weight_scale_name].float()
        elif device.type == "cpu":
            weight_shape = weight_dict[self.weight_name].shape
            weight_dtype = weight_dict[self.weight_name].dtype
            self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
            self.pin_weight.copy_(weight_dict[self.weight_name])

            weight_scale_shape = weight_dict[self.weight_scale_name].shape
            weight_scale_dtype = torch.float
            self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
            self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
            # NOTE(review): only the weight entry is deleted here; the scale
            # entry stays in weight_dict — confirm whether that is intended.
            del weight_dict[self.weight_name]
        else:
            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")

    def load_fp8_perchannel_sym(self, weight_dict):
        """Load (or quantize on the fly) an fp8-e4m3 per-channel symmetric weight."""
        if self.config.get("weight_auto_quant", False):
            # Quantize a full-precision checkpoint weight at load time.
            self.weight = weight_dict[self.weight_name].to(torch.float32)
            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.float8_e4m3fn)
            self.weight_scale = self.weight_scale.to(torch.float32)
        else:
            self.load_quantized(weight_dict)

        if self.bias_name is not None:
            device = weight_dict[self.bias_name].device
            if device.type == "cuda":
                self.bias = weight_dict[self.bias_name]
            elif device.type == "cpu":
                bias_shape = weight_dict[self.bias_name].shape
                bias_dtype = weight_dict[self.bias_name].dtype
                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                self.pin_bias.copy_(weight_dict[self.bias_name])
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
        else:
            self.bias = None
            self.pin_bias = None

    def load_int8_perchannel_sym(self, weight_dict):
        """Load (or quantize on the fly) an int8 per-channel symmetric weight."""
        if self.config.get("weight_auto_quant", False):
            self.weight = weight_dict[self.weight_name].to(torch.float32)
            w_quantizer = IntegerQuantizer(8, True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.int8)
            self.weight_scale = self.weight_scale.to(torch.float32)
        else:
            self.load_quantized(weight_dict)

        if self.bias_name is not None:
            device = weight_dict[self.bias_name].device
            if device.type == "cuda":
                self.bias = weight_dict[self.bias_name]
            elif device.type == "cpu":
                bias_shape = weight_dict[self.bias_name].shape
                bias_dtype = weight_dict[self.bias_name].dtype
                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                self.pin_bias.copy_(weight_dict[self.bias_name])
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
        else:
            self.bias = None
            self.pin_bias = None

    def load_fp8_perblock128_sym(self, weight_dict):
        """Load (or quantize on the fly) an fp8 weight with 128x128 block scales."""
        if self.config.get("weight_auto_quant", False):
            self.weight = weight_dict[self.weight_name]
            self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight)
        else:
            self.load_quantized(weight_dict)

        if self.bias_name is not None:
            device = weight_dict[self.bias_name].device
            if device.type == "cuda":
                self.bias = weight_dict[self.bias_name]
            elif device.type == "cpu":
                bias_shape = weight_dict[self.bias_name].shape
                bias_dtype = weight_dict[self.bias_name].dtype
                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                self.pin_bias.copy_(weight_dict[self.bias_name])
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
        else:
            self.bias = None
            self.pin_bias = None

    def per_block_cast_to_fp8(self, x):
        """Quantize a 2-D tensor to fp8-e4m3 with one scale per 128x128 block.

        The tensor is zero-padded up to multiples of 128; 448.0 is the max
        representable magnitude of float8_e4m3fn.
        Returns (fp8 tensor with original shape, per-block scale matrix).
        """
        assert x.dim() == 2
        m, n = x.shape
        x_padded = torch.zeros(
            (deep_gemm.ceil_div(m, 128) * 128, deep_gemm.ceil_div(n, 128) * 128),
            dtype=x.dtype,
            device=x.device,
        )
        x_padded[:m, :n] = x
        x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
        # clamp avoids a zero amax (division by zero) for all-zero blocks.
        x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
        x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))

    # =========================
    # act quant kernels
    # =========================

    def act_quant_int8_perchannel_sym_torchao(self, x):
        """Per-token int8 dynamic quantization via torchao."""
        input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannel_sym_vllm(self, x):
        """Per-token fp8 dynamic quantization via vllm custom ops."""
        input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannel_sym_sgl(self, x):
        """Per-token fp8 dynamic quantization via sgl_kernel (writes into preallocated buffers)."""
        m, k = x.shape
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
        input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device="cuda", requires_grad=False)
        sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale)
        return input_tensor_quant, input_tensor_scale

    def act_quant_int8_perchannel_sym_vllm(self, x):
        """Per-token int8 symmetric dynamic quantization via vllm custom ops."""
        input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannelgroup128_sym_deepgemm(self, x):
        """Per-token, per-128-group fp8 dynamic quantization (pure torch, deepgemm layout)."""
        assert x.dim() == 2 and x.size(1) % 128 == 0
        m, n = x.shape
        x_view = x.view(m, -1, 128)
        x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
        return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)

    def act_quant_fp8_perchannelgroup128_sym_sgl(self, x):
        """Per-token, per-128-group fp8 dynamic quantization via sgl_kernel."""
        m, k = x.shape
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
        input_tensor_scale = torch.empty((m, k // 128), dtype=torch.float32, device="cuda", requires_grad=False)
        sgl_kernel.sgl_per_token_group_quant_fp8(
            x,
            input_tensor_quant,
            input_tensor_scale,
            group_size=128,
            eps=1e-10,
            fp8_min=-448.0,
            fp8_max=448.0,
        )
        return input_tensor_quant, input_tensor_scale

    def state_dict(self, destination=None):
        """Export weight/bias/scale in checkpoint layout (undo the transpose)."""
        if destination is None:
            destination = {}
        if self.weight_need_transpose:
            destination[self.weight_name] = self.weight.cpu().detach().clone().t().contiguous()
        else:
            destination[self.weight_name] = self.weight.cpu().detach().clone().contiguous()
        if hasattr(self, "bias") and self.bias is not None:
            destination[self.bias_name] = self.bias.cpu().detach().clone()
        if hasattr(self, "weight_scale"):
            destination[self.weight_name.removesuffix(".weight") + ".weight_scale"] = self.weight_scale.cpu().detach().clone()
        return destination

363

364
@MM_WEIGHT_REGISTER("fp8-vllm")
class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        # Hooks: fp8 per-channel weights, vllm dynamic per-token act quant,
        # weight stored transposed for the cutlass kernel.
        self.weight_need_transpose = True
        self.load_func = self.load_fp8_perchannel_sym
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        # Quantize activations per token, then run the fused scaled matmul.
        quantized_act, act_scale = self.act_quant_func(input_tensor)
        out = torch.empty(
            (input_tensor.shape[0], self.weight.shape[1]),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )
        torch.ops._C.cutlass_scaled_mm(
            out,
            quantized_act,
            self.weight,
            act_scale,
            self.weight_scale,
            self.bias,
        )
        return out


399
@MM_WEIGHT_REGISTER("int8-vllm")
class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        # Hooks: int8 per-channel weights, vllm dynamic per-token act quant,
        # weight stored transposed for the cutlass kernel.
        self.weight_need_transpose = True
        self.load_func = self.load_int8_perchannel_sym
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        # Quantize activations per token, then run the fused scaled matmul.
        quantized_act, act_scale = self.act_quant_func(input_tensor)
        out = torch.empty(
            (input_tensor.shape[0], self.weight.shape[1]),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )
        torch.ops._C.cutlass_scaled_mm(
            out,
            quantized_act,
            self.weight,
            act_scale,
            self.weight_scale,
            self.bias,
        )
        return out


434
@MM_WEIGHT_REGISTER("fp8-q8f")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        # Q8F expects the weight in its checkpoint layout (no transpose).
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        # Guard the bias cast: the fp8 loader sets self.bias = None when the
        # layer has no bias, and None.float() would raise AttributeError.
        bias = self.bias.float() if self.bias is not None else None
        output_tensor = Q8F.linear.fp8_linear(
            input_tensor_quant,
            self.weight,
            bias,
            input_tensor_scale,
            self.weight_scale,
            out_dtype=self.infer_dtype,
        )
        return output_tensor.squeeze(0)


464
@MM_WEIGHT_REGISTER("int8-q8f")
class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        # Q8F expects the weight in its checkpoint layout (no transpose).
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        # Guard the bias cast: the int8 loader sets self.bias = None when the
        # layer has no bias, and None.float() would raise AttributeError.
        bias = self.bias.float() if self.bias is not None else None
        output_tensor = Q8F.linear.q8_linear(
            input_tensor_quant,
            self.weight,
            bias,
            input_tensor_scale,
            self.weight_scale,
            fuse_gelu=False,
            out_dtype=self.infer_dtype,
        )
        return output_tensor.squeeze(0)


495
@MM_WEIGHT_REGISTER("fp8-b128-deepgemm")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl

    Quant MM:
        Weight: fp8 perblock 128x128 sym
        Act: fp8 pertoken-pergroup group=128 dynamic sym
        Kernel: quant-mm using Deepgemm, act dynamic quant using Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        # Deepgemm takes the weight in NT layout, so no transpose on load.
        self.weight_need_transpose = False
        self.load_func = self.load_fp8_perblock128_sym
        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_sgl

    def apply(self, input_tensor):
        # Output has one row per token and one column per output channel
        # (weight rows, since the weight is kept untransposed).
        out = torch.empty(
            (input_tensor.shape[0], self.weight.shape[0]),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )
        quantized_act, act_scale = self.act_quant_func(input_tensor)
        deep_gemm.gemm_fp8_fp8_bf16_nt(
            (quantized_act, act_scale),
            (self.weight, self.weight_scale),
            out,
        )
        if hasattr(self, "bias") and self.bias is not None:
            out.add_(self.bias)
        return out


529
@MM_WEIGHT_REGISTER("fp8-sgl")
class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        # Hooks: fp8 per-channel weight (transposed), sgl per-token act quant.
        self.weight_need_transpose = True
        self.load_func = self.load_fp8_perchannel_sym
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl

    def apply(self, input_tensor):
        quantized_act, act_scale = self.act_quant_func(input_tensor)
        # sgl's fused kernel handles scaling and bias addition internally.
        return sgl_kernel.fp8_scaled_mm(
            quantized_act,
            self.weight,
            act_scale,
            self.weight_scale,
            self.infer_dtype,
            bias=self.bias,
        )


559
@MM_WEIGHT_REGISTER("int8-sgl")
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        # int8_scaled_mm allocates and returns its own output, so the
        # previous torch.empty pre-allocation was dead work and is removed.
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = sgl_kernel.int8_scaled_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.infer_dtype,
            self.bias,
        )
        return output_tensor
592
593


594
@MM_WEIGHT_REGISTER("int8-torchao")
class MMWeightWint8channelAint8channeldynamicTorchao(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Torchao
    """

    # NOTE: renamed from MMWeightWint8channelAint8channeldynamicSglActVllm,
    # which silently shadowed the identically-named "int8-sgl" class defined
    # above in this module (the registry key "int8-torchao" is unchanged).

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=self.infer_dtype)
        if self.bias is not None:
            output_tensor = output_tensor + self.bias

        return output_tensor


621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
class MMWeightGGUFTemplate(MMWeightQuantTemplate):
    """Base class for GGUF-format quantized weights."""

    # GGML qtypes torch can use directly without dequantization.  Guarded so
    # importing this module does not crash with AttributeError when the
    # optional gguf package is missing (the try/except above sets gguf = None).
    TORCH_COMPATIBLE_QTYPES = (None,) if gguf is None else (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16)

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

    def dequantize_func(self):
        # TODO: implement dequantize_func
        pass


@MM_WEIGHT_REGISTER("W-gguf-Q4_K")
class MMWeightGGUFQ4K(MMWeightGGUFTemplate):
    # Placeholder for GGUF Q4_K weights; dequantize_func is not implemented yet.
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

637

638
@MM_WEIGHT_REGISTER("int4-g128-marlin")
class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
    """
    Name: W-int4-group128-sym-Marlin

    Quant int4 x FP16:
        Weight: int4 pergroup sym
        Kernel: Marlin
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_quantized

    def load(self, weight_dict):
        # Lazy loading is not supported for Marlin-packed weights.
        assert not self.lazy_load
        self.load_func(weight_dict)
        # Marlin requires a per-layer scratch workspace tensor from the checkpoint.
        self.workspace = weight_dict[f"{self.weight_name}_workspace"]
        if self.bias_name is not None:
            bias_shape = weight_dict[self.bias_name].shape
            bias_dtype = weight_dict[self.bias_name].dtype
            # NOTE(review): the bias is staged in pinned host memory but bound
            # to self.bias (not self.pin_bias) — confirm against to_cuda().
            self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
            self.bias.copy_(weight_dict[self.bias_name])
        else:
            self.bias = None

    def apply(self, input_tensor):
        # Output channel count comes from the scale matrix (one column per channel).
        output_tensor = torch.empty(input_tensor.shape[:-1] + (self.weight_scale.shape[1],), dtype=input_tensor.dtype, device=input_tensor.device)
        # Trailing -1s let the kernel auto-tune its thread/parallelism config.
        marlin_cuda_quant.mul(input_tensor, self.weight, output_tensor, self.weight_scale.half(), self.workspace, -1, -1, -1, -1)
        if hasattr(self, "bias") and self.bias is not None:
            output_tensor.add_(self.bias)
        return output_tensor