mm_weight.py 32 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
from abc import ABCMeta, abstractmethod
PengGao's avatar
PengGao committed
2
3

import torch
root's avatar
root committed
4
from loguru import logger
Dongz's avatar
Dongz committed
5

PengGao's avatar
PengGao committed
6
7
8
9
from lightx2v.utils.envs import *
from lightx2v.utils.quant_utils import FloatQuantizer, IntegerQuantizer
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER

gushiqiao's avatar
gushiqiao committed
10
11
12
13
14
15
16
17
18
19
try:
    from vllm import _custom_ops as ops
except ImportError:
    ops = None

try:
    import sgl_kernel
except ImportError:
    sgl_kernel = None

20
21
22
23
try:
    import q8_kernels.functional as Q8F
except ImportError:
    Q8F = None
helloyongyang's avatar
helloyongyang committed
24

25
26
27
28
29
try:
    import deep_gemm
except ImportError:
    deep_gemm = None

gushiqiao's avatar
gushiqiao committed
30
try:
Wq-dd's avatar
Wq-dd committed
31
    from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
gushiqiao's avatar
gushiqiao committed
32
33
34
except ModuleNotFoundError:
    quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None

35
36
37
38
39
try:
    import gguf
except ImportError:
    gguf = None

40
41
42
43
try:
    import marlin_cuda_quant
except ModuleNotFoundError:
    marlin_cuda_quant = None
helloyongyang's avatar
helloyongyang committed
44

45

helloyongyang's avatar
helloyongyang committed
46
class MMWeightTemplate(metaclass=ABCMeta):
    """Abstract base for matmul weight holders.

    Subclasses implement ``load`` (populate tensors from a weight dict) and
    ``apply`` (run the matmul on an input tensor). This base stores the
    parameter names, lazy-load settings and a config dict, and provides
    helpers that move the held tensors between CPU and CUDA.
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.config = {}

    @abstractmethod
    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        """Attach a config dict to this weight.

        Args:
            config: configuration mapping; ``None`` (the default) means an empty
                config. A fresh dict is created per call so that instances never
                share (and mutate) one default dict object.
        """
        self.config = {} if config is None else config

    def _move_tensors(self, move):
        """Apply ``move`` to ``weight`` and to ``weight_scale``/``bias`` when they are set.

        Optional tensors that are missing or ``None`` (e.g. after a subclass
        ``clear()``) are skipped instead of raising.
        """
        self.weight = move(self.weight)
        for attr in ("weight_scale", "bias"):
            tensor = getattr(self, attr, None)
            if tensor is not None:
                setattr(self, attr, move(tensor))

    def to_cuda(self, non_blocking=False):
        """Move the held tensors to the current CUDA device."""
        self._move_tensors(lambda t: t.cuda(non_blocking=non_blocking))

    def to_cpu(self, non_blocking=False):
        """Move the held tensors to host memory."""
        self._move_tensors(lambda t: t.to("cpu", non_blocking=non_blocking))
78

helloyongyang's avatar
helloyongyang committed
79

Dongz's avatar
Dongz committed
80
@MM_WEIGHT_REGISTER("Default")
class MMWeight(MMWeightTemplate):
    """Unquantized matmul weight: ``out = input @ weight (+ bias)``.

    The checkpoint weight is stored transposed (``.t()`` in ``load``) and is
    transposed back in ``state_dict``, so ``apply`` can call ``torch.mm`` /
    ``torch.addmm`` directly.
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

    @staticmethod
    def _pinned_copy(tensor):
        """Return a copy of CPU ``tensor`` in newly allocated pinned host memory."""
        pinned = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
        pinned.copy_(tensor)
        return pinned

    def load(self, weight_dict):
        """Populate ``weight`` (transposed) and ``bias`` from ``weight_dict``.

        CUDA tensors are referenced directly; CPU tensors are copied into pinned
        host memory. ``bias`` is always assigned (``None`` when ``bias_name`` is
        None) so ``apply`` and ``_calculate_size`` can read it on every device.

        Raises:
            ValueError: if the weight is on a device other than CPU or CUDA.
        """
        weight = weight_dict[self.weight_name]
        device = weight.device
        if device.type == "cuda":
            self.weight = weight.t()
            self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
        elif device.type == "cpu":
            self.weight = self._pinned_copy(weight.t())
            self.bias = self._pinned_copy(weight_dict[self.bias_name]) if self.bias_name is not None else None
        else:
            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")

    def _calculate_size(self):
        """Return the number of bytes held by ``weight`` plus ``bias`` (if any)."""
        size = self.weight.numel() * self.weight.element_size()
        if self.bias is not None:
            size += self.bias.numel() * self.bias.element_size()
        return size

    def apply(self, input_tensor):
        """Compute ``input_tensor @ weight (+ bias)`` into a freshly allocated output."""
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        if self.bias is None:
            return torch.mm(input_tensor, self.weight, out=output_tensor)
        return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)

    def state_dict(self, destination=None):
        """Write CPU copies of the weight (back in checkpoint layout) and bias into ``destination``."""
        if destination is None:
            destination = {}
        destination[self.weight_name] = self.weight.cpu().detach().clone().t().contiguous()
        if hasattr(self, "bias") and self.bias is not None:
            destination[self.bias_name] = self.bias.cpu().detach().clone()
        return destination

helloyongyang's avatar
helloyongyang committed
130

Dongz's avatar
Dongz committed
131
@MM_WEIGHT_REGISTER("Default-Force-FP32")
class MMWeightForceFP32(MMWeight):
    """``MMWeight`` variant that upcasts its weight and bias to float32 after loading."""

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

    def load(self, weight_dict):
        super().load(weight_dict)
        target = torch.float32
        self.weight = self.weight.to(target)
        # The bias may be absent entirely or explicitly None; only cast a real tensor.
        existing_bias = getattr(self, "bias", None)
        if existing_bias is not None:
            self.bias = existing_bias.to(target)


143
class MMWeightQuantTemplate(MMWeightTemplate):
    """Shared machinery for quantized matmul weights.

    Concrete subclasses set, in ``__init__``:
      * ``load_func``             - one of the ``load_*`` methods below;
      * ``weight_need_transpose`` - whether the stored weight is ``.t()``-ed after loading;
      * ``act_quant_func``        - one of the ``act_quant_*`` kernels below;
    and implement ``apply`` with the matching quantized GEMM kernel.
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        # Checkpoint key of the scale tensor paired with the weight.
        self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale"
        self.load_func = None
        self.weight_need_transpose = True
        self.act_quant_func = None
        self.infer_dtype = GET_DTYPE()

    # =========================
    # weight load functions
    # =========================

    @staticmethod
    def _pinned_copy(tensor, dtype=None):
        """Copy ``tensor`` into a new pinned CPU tensor, optionally cast to ``dtype``."""
        pinned = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype if dtype is None else dtype)
        pinned.copy_(tensor)
        return pinned

    def load_from_disk(self):  # Need Rewrite
        """Read weight, float32 scale and optional bias from ``self.lazy_load_file``.

        Tensors are pinned unless ``torch._dynamo`` is tracing. ``bias`` is set to
        ``None`` when ``bias_name`` is None so ``apply`` can always read it.
        """
        weight = self.lazy_load_file.get_tensor(self.weight_name)
        weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float()
        bias = None
        if self.bias_name is not None:
            bias = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)

        if not torch._dynamo.is_compiling():
            weight = weight.pin_memory()
            weight_scale = weight_scale.pin_memory()
            if bias is not None:
                bias = bias.pin_memory()

        self.weight = weight
        self.weight_scale = weight_scale
        self.bias = bias

        if self.weight_need_transpose:
            self.weight = self.weight.t()

    def load(self, weight_dict):
        """Eagerly load via ``load_func`` unless lazy loading is enabled."""
        if not self.lazy_load:
            self.load_func(weight_dict)
            if self.weight_need_transpose:
                self.weight = self.weight.t()

    def clear(self):
        """Drop references to the loaded tensors by setting existing ones to ``None``."""
        for attr in ("weight", "weight_scale", "bias", "pinned_weight", "pinned_weight_scale", "pinned_bias"):
            if hasattr(self, attr):
                setattr(self, attr, None)

    def _calculate_size(self):
        """Return the bytes held by ``weight``, ``weight_scale`` and ``bias`` (if any)."""
        size = self.weight.numel() * self.weight.element_size()
        size += self.weight_scale.numel() * self.weight_scale.element_size()
        if self.bias is not None:
            size += self.bias.numel() * self.bias.element_size()
        return size

    def _load_bias(self, weight_dict):
        """Set ``self.bias`` from ``weight_dict`` (pinned copy if on CPU), or ``None`` if there is no bias.

        Raises:
            ValueError: if the bias is on a device other than CPU or CUDA.
        """
        if self.bias_name is None:
            self.bias = None
            return
        bias = weight_dict[self.bias_name]
        device = bias.device
        if device.type == "cuda":
            self.bias = bias
        elif device.type == "cpu":
            self.bias = self._pinned_copy(bias)
        else:
            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")

    def load_quantized(self, weight_dict):
        """Load an already-quantized weight and its scale (scale cast to float32).

        Raises:
            ValueError: if the weight is on a device other than CPU or CUDA.
        """
        weight = weight_dict[self.weight_name]
        device = weight.device
        if device.type == "cuda":
            self.weight = weight
            self.weight_scale = weight_dict[self.weight_scale_name].float()
        elif device.type == "cpu":
            self.weight = self._pinned_copy(weight)
            self.weight_scale = self._pinned_copy(weight_dict[self.weight_scale_name], dtype=torch.float)
        else:
            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")

    def load_fp8_perchannel_sym(self, weight_dict):
        """Load (or, with ``weight_auto_quant``, quantize) an fp8-e4m3 per-channel weight, then the bias."""
        if self.config.get("weight_auto_quant", False):
            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
            weight, weight_scale, _ = w_quantizer.real_quant_tensor(weight_dict[self.weight_name].to(torch.float32))
            self.weight = weight.to(torch.float8_e4m3fn)
            self.weight_scale = weight_scale.to(torch.float32)
        else:
            self.load_quantized(weight_dict)
        self._load_bias(weight_dict)

    def load_int8_perchannel_sym(self, weight_dict):
        """Load (or, with ``weight_auto_quant``, quantize) an int8 per-channel weight, then the bias."""
        if self.config.get("weight_auto_quant", False):
            w_quantizer = IntegerQuantizer(8, True, "per_channel")
            weight, weight_scale, _ = w_quantizer.real_quant_tensor(weight_dict[self.weight_name].to(torch.float32))
            self.weight = weight.to(torch.int8)
            self.weight_scale = weight_scale.to(torch.float32)
        else:
            self.load_quantized(weight_dict)
        self._load_bias(weight_dict)

    def load_fp8_perblock128_sym(self, weight_dict):
        """Load (or, with ``weight_auto_quant``, quantize) an fp8 128x128-block weight, then the bias."""
        if self.config.get("weight_auto_quant", False):
            self.weight, self.weight_scale = self.per_block_cast_to_fp8(weight_dict[self.weight_name])
        else:
            self.load_quantized(weight_dict)
        self._load_bias(weight_dict)

    def per_block_cast_to_fp8(self, x, block=128):
        """Quantize 2-D ``x`` to fp8-e4m3 with one scale per ``block`` x ``block`` tile.

        ``x`` is zero-padded up to a multiple of ``block`` in both dims for the
        scale computation; the returned fp8 tensor is cropped back to ``x.shape``.
        448.0 is the largest finite float8_e4m3fn value.

        Returns:
            (fp8 tensor of shape ``x.shape``, float32 scales of shape
            ``(ceil(m / block), ceil(n / block))``).
        """
        assert x.dim() == 2
        m, n = x.shape
        # Round up locally so auto-quantization does not require deep_gemm to be installed.
        padded_m = (m + block - 1) // block * block
        padded_n = (n + block - 1) // block * block
        x_padded = torch.zeros((padded_m, padded_n), dtype=x.dtype, device=x.device)
        x_padded[:m, :n] = x
        x_view = x_padded.view(-1, block, x_padded.size(1) // block, block)
        x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
        x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))

    # =========================
    # act quant kernels
    # =========================

    def act_quant_int8_perchannel_sym_torchao(self, x):
        input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannel_sym_vllm(self, x):
        input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannel_sym_sgl(self, x):
        m, k = x.shape
        # Allocate on the input's device so multi-GPU inputs do not land on the default CUDA device.
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device=x.device, requires_grad=False)
        input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device=x.device, requires_grad=False)
        sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale)
        return input_tensor_quant, input_tensor_scale

    def act_quant_int8_perchannel_sym_vllm(self, x):
        input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannelgroup128_sym_deepgemm(self, x):
        """Quantize 2-D ``x`` to fp8 with one scale per row per 128-column group."""
        assert x.dim() == 2 and x.size(1) % 128 == 0
        m, n = x.shape
        x_view = x.view(m, -1, 128)
        x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
        return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)

    def act_quant_fp8_perchannelgroup128_sym_sgl(self, x):
        m, k = x.shape
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device=x.device, requires_grad=False)
        input_tensor_scale = torch.empty((m, k // 128), dtype=torch.float32, device=x.device, requires_grad=False)
        sgl_kernel.sgl_per_token_group_quant_fp8(
            x,
            input_tensor_quant,
            input_tensor_scale,
            group_size=128,
            eps=1e-10,
            fp8_min=-448.0,
            fp8_max=448.0,
        )
        return input_tensor_quant, input_tensor_scale

    def state_dict(self, destination=None):
        """Write CPU copies of weight (checkpoint layout), bias and weight scale into ``destination``.

        Tensors that are unset or ``None`` (e.g. after ``clear()``) are skipped,
        except ``weight``, which is required.
        """
        if destination is None:
            destination = {}
        weight = self.weight.cpu().detach().clone()
        if self.weight_need_transpose:
            weight = weight.t()
        destination[self.weight_name] = weight.contiguous()
        if getattr(self, "bias", None) is not None:
            destination[self.bias_name] = self.bias.cpu().detach().clone()
        if getattr(self, "weight_scale", None) is not None:
            destination[self.weight_scale_name] = self.weight_scale.cpu().detach().clone()
        return destination

349

Dongz's avatar
Dongz committed
350
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm")
class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        """Quantize the activation, then run vLLM's CUTLASS scaled matmul into a new buffer."""
        rows = input_tensor.shape[0]
        cols = self.weight.shape[1]
        result = torch.empty((rows, cols), dtype=input_tensor.dtype, device=input_tensor.device, requires_grad=False)

        act_q, act_scale = self.act_quant_func(input_tensor)
        torch.ops._C.cutlass_scaled_mm(result, act_q, self.weight, act_scale, self.weight_scale, self.bias)
        return result


Dongz's avatar
Dongz committed
385
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm")
class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        """Quantize the activation to int8, then run vLLM's CUTLASS scaled matmul into a new buffer."""
        rows = input_tensor.shape[0]
        cols = self.weight.shape[1]
        result = torch.empty((rows, cols), dtype=input_tensor.dtype, device=input_tensor.device, requires_grad=False)

        act_q, act_scale = self.act_quant_func(input_tensor)
        torch.ops._C.cutlass_scaled_mm(result, act_q, self.weight, act_scale, self.weight_scale, self.bias)
        return result


420
421
422
423
424
425
426
427
428
429
430
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        """Quantize the activation and run Q8F's fp8 linear; the result is squeezed on dim 0."""
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        # Layers without a bias have self.bias = None; calling .float() on it would raise.
        # NOTE(review): assumes Q8F.linear.fp8_linear accepts bias=None — confirm.
        bias = self.bias.float() if self.bias is not None else None
        output_tensor = Q8F.linear.fp8_linear(
            input_tensor_quant,
            self.weight,
            bias,
            input_tensor_scale,
            self.weight_scale,
            out_dtype=self.infer_dtype,
        )
        return output_tensor.squeeze(0)


Dongz's avatar
Dongz committed
450
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F")
class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        """Quantize the activation to int8 and run Q8F's int8 linear; the result is squeezed on dim 0."""
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        # Layers without a bias have self.bias = None; calling .float() on it would raise.
        # NOTE(review): assumes Q8F.linear.q8_linear accepts bias=None — confirm.
        bias = self.bias.float() if self.bias is not None else None
        output_tensor = Q8F.linear.q8_linear(
            input_tensor_quant,
            self.weight,
            bias,
            input_tensor_scale,
            self.weight_scale,
            fuse_gelu=False,
            out_dtype=self.infer_dtype,
        )
        return output_tensor.squeeze(0)


481
482
@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemm(MMWeightQuantTemplate):
    """
    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm

    Quant MM:
        Weight: fp8 perblock 128x128 sym
        Act: fp8 perchannel-pergroup group=128 dynamic sym
        Kernel: Deepgemm

    Reference: https://github.com/deepseek-ai/DeepGEMM

    Example:
        Act(1024, 2048) x Weight(2048, 4096) = Out(1024, 4096)

        Act : torch.Size([1024, 2048]), torch.float8_e4m3fn
        Act Scale: torch.Size([1024, 16]), torch.float32
        Weight : torch.Size([4096, 2048]), torch.float8_e4m3fn
        Weight Scale: torch.Size([32, 16]), torch.float32
        Out : torch.Size([1024, 4096]), self.infer_dtype
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perblock128_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_deepgemm

    def apply(self, input_tensor):
        """Run DeepGEMM's fp8 NT GEMM (weight kept untransposed), then add the bias if present."""
        out_rows = input_tensor.shape[0]
        out_cols = self.weight.shape[0]
        result = torch.empty((out_rows, out_cols), dtype=input_tensor.dtype, device=input_tensor.device, requires_grad=False)

        act_q, act_scale = self.act_quant_func(input_tensor)
        deep_gemm.gemm_fp8_fp8_bf16_nt((act_q, act_scale), (self.weight, self.weight_scale), result)

        bias = getattr(self, "bias", None)
        if bias is not None:
            result.add_(bias)
        return result


@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl

    Quant MM:
        Weight: fp8 perblock 128x128 sym
        Act: fp8 pertoken-pergroup group=128 dynamic sym
        Kernel: quant-mm using Deepgemm, act dynamic quant using Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perblock128_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_sgl

    def apply(self, input_tensor):
        """Quantize the activation with sgl-kernel, run DeepGEMM's fp8 NT GEMM, then add the bias if present."""
        out_rows = input_tensor.shape[0]
        out_cols = self.weight.shape[0]
        result = torch.empty((out_rows, out_cols), dtype=input_tensor.dtype, device=input_tensor.device, requires_grad=False)

        act_q, act_scale = self.act_quant_func(input_tensor)
        deep_gemm.gemm_fp8_fp8_bf16_nt((act_q, act_scale), (self.weight, self.weight_scale), result)

        bias = getattr(self, "bias", None)
        if bias is not None:
            result.add_(bias)
        return result


@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl")
class MMWeightWfp8channelAfp8channeldynamicVllmActSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: quant-mm using vllm, act dynamic quant using Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl

    def apply(self, input_tensor):
        """Quantize the activation with sgl-kernel, then run vLLM's CUTLASS scaled matmul into a new buffer."""
        rows = input_tensor.shape[0]
        cols = self.weight.shape[1]
        result = torch.empty((rows, cols), dtype=input_tensor.dtype, device=input_tensor.device, requires_grad=False)

        act_q, act_scale = self.act_quant_func(input_tensor)
        torch.ops._C.cutlass_scaled_mm(result, act_q, self.weight, act_scale, self.weight_scale, self.bias)
        return result


helloyongyang's avatar
helloyongyang committed
595
596
597
598
599
600
601
602
603
604
605
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm")
class MMWeightWfp8channelAfp8channeldynamicSglActVllm(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        """Quantize the activation with vLLM and return sgl-kernel's fp8 scaled matmul result."""
        act_q, act_scale = self.act_quant_func(input_tensor)
        return sgl_kernel.fp8_scaled_mm(
            act_q,
            self.weight,
            act_scale,
            self.weight_scale,
            self.infer_dtype,
            bias=self.bias,
        )


625
626
627
628
629
630
631
632
633
634
635
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl")
class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl

    def apply(self, input_tensor):
        """Quantize the activation with sgl-kernel and return sgl-kernel's fp8 scaled matmul result."""
        act_q, act_scale = self.act_quant_func(input_tensor)
        return sgl_kernel.fp8_scaled_mm(
            act_q,
            self.weight,
            act_scale,
            self.weight_scale,
            self.infer_dtype,
            bias=self.bias,
        )


@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm")
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        # Flag consumed by the template's loading path (handling not visible here).
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        """Quantize ``input_tensor`` to int8 and return ``sgl_kernel.int8_scaled_mm`` of it with the weight.

        Args:
            input_tensor: activation to multiply with the stored int8 weight.

        Returns:
            The tensor returned by ``sgl_kernel.int8_scaled_mm`` in ``self.infer_dtype``,
            with ``self.bias`` passed to the kernel.
        """
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        # The kernel returns its own result tensor, so no output buffer is
        # preallocated here (a preallocated one would just be discarded).
        return sgl_kernel.int8_scaled_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.infer_dtype,
            self.bias,
        )


@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao")
class MMWeightWint8channelAint8channeldynamicTorchao(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Torchao
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        # Flag consumed by the template's loading path (handling not visible here).
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao

    def apply(self, input_tensor):
        """Quantize ``input_tensor`` per token and multiply it with the int8 weight via torchao.

        Args:
            input_tensor: activation to multiply with the stored int8 weight.

        Returns:
            The matmul result in ``self.infer_dtype``, plus ``self.bias`` when one is loaded.
        """
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        # The weight scale is transposed and cast to float32 before being handed to torchao.
        output_tensor = quant_int8_per_token_matmul(
            input_tensor_quant,
            input_tensor_scale,
            self.weight,
            self.weight_scale.t().float(),
            output_dtype=self.infer_dtype,
        )
        if self.bias is not None:
            output_tensor = output_tensor + self.bias
        return output_tensor


class MMWeightGGUFTemplate(MMWeightQuantTemplate):
    """Base template for matmul weights stored in GGUF quantization formats."""

    # `gguf` is an optional dependency and is set to None when its import fails,
    # so only touch its attributes when it is available; otherwise importing this
    # module would raise AttributeError at class-definition time.
    # NOTE(review): presumably the qtypes usable by torch without dequantization — confirm with callers.
    TORCH_COMPATIBLE_QTYPES = (None,) if gguf is None else (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16)

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

    def dequantize_func(self):
        # TODO: implement dequantize_func
        pass


@MM_WEIGHT_REGISTER("W-gguf-Q4_K")
class MMWeightGGUFQ4K(MMWeightGGUFTemplate):
    """GGUF Q4_K weight; currently only forwards construction to the GGUF template."""

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load=lazy_load, lazy_load_file=lazy_load_file)


@MM_WEIGHT_REGISTER("W-int4-group128-sym-Marlin")
class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
    """
    Name: W-int4-group128-sym-Marlin

    Quant int4 x FP16:
        Weight: int4 pergroup sym
        Kernel: Marlin
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_quantized

    def load(self, weight_dict):
        """Load the pre-quantized weight, the Marlin workspace and the optional bias.

        Lazy loading is rejected by an assertion.
        """
        assert not self.lazy_load
        self.load_func(weight_dict)
        self.workspace = weight_dict[f"{self.weight_name}_workspace"]

        if self.bias_name is None:
            self.bias = None
            return

        # Copy the bias into a newly allocated pinned host tensor of the same shape/dtype.
        src_bias = weight_dict[self.bias_name]
        self.bias = torch.empty(src_bias.shape, pin_memory=True, dtype=src_bias.dtype)
        self.bias.copy_(src_bias)

    def apply(self, input_tensor):
        """Multiply ``input_tensor`` by the int4 weight with ``marlin_cuda_quant.mul``, then add the bias if any."""
        out_features = self.weight_scale.shape[1]
        output_tensor = torch.empty(
            (*input_tensor.shape[:-1], out_features),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
        )
        # The kernel fills output_tensor in place. NOTE(review): the trailing -1
        # arguments presumably request the kernel's default tiling — confirm.
        marlin_cuda_quant.mul(input_tensor, self.weight, output_tensor, self.weight_scale.half(), self.workspace, -1, -1, -1, -1)
        if getattr(self, "bias", None) is not None:
            output_tensor.add_(self.bias)
        return output_tensor


if __name__ == "__main__":
    # Smoke test: load a few registered quantized matmul weights and log output shapes.
    # Requires a CUDA device (the inputs are moved to GPU).

    def _prequantized_fp8_weights():
        return {
            "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
            "xx.bias": torch.randn(8192).to(torch.bfloat16),
            "xx.weight_scale": torch.randn(8192, 1).to(torch.float32),
        }

    def _float_weights():
        return {
            "xx.weight": torch.randn(8192, 4096),
            "xx.bias": torch.randn(8192).to(torch.bfloat16),
        }

    # (registry key, weight factory, weight_auto_quant)
    smoke_cases = [
        ("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm", _prequantized_fp8_weights, False),
        ("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm", _float_weights, True),
        ("W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", _float_weights, True),
    ]

    for mm_type, make_weights, auto_quant in smoke_cases:
        weight_dict = make_weights()
        mm_weight = MM_WEIGHT_REGISTER[mm_type]("xx.weight", "xx.bias")
        mm_weight.set_config({"weight_auto_quant": auto_quant})
        mm_weight.load(weight_dict)
        input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
        output_tensor = mm_weight.apply(input_tensor)
        logger.info(output_tensor.shape)