mm_weight.py 28.9 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
from abc import ABCMeta, abstractmethod
PengGao's avatar
PengGao committed
2
3

import torch
root's avatar
root committed
4
from loguru import logger
Dongz's avatar
Dongz committed
5

PengGao's avatar
PengGao committed
6
7
8
9
from lightx2v.utils.envs import *
from lightx2v.utils.quant_utils import FloatQuantizer, IntegerQuantizer
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER

gushiqiao's avatar
gushiqiao committed
10
11
12
13
14
15
16
17
18
19
try:
    from vllm import _custom_ops as ops
except ImportError:
    ops = None

try:
    import sgl_kernel
except ImportError:
    sgl_kernel = None

20
21
22
23
try:
    import q8_kernels.functional as Q8F
except ImportError:
    Q8F = None
helloyongyang's avatar
helloyongyang committed
24

25
26
27
28
29
try:
    import deep_gemm
except ImportError:
    deep_gemm = None

gushiqiao's avatar
gushiqiao committed
30
try:
Wq-dd's avatar
Wq-dd committed
31
    from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
gushiqiao's avatar
gushiqiao committed
32
33
34
except ModuleNotFoundError:
    quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None

35
36
37
38
39
try:
    import gguf
except ImportError:
    gguf = None

helloyongyang's avatar
helloyongyang committed
40
41

class MMWeightTemplate(metaclass=ABCMeta):
gushiqiao's avatar
fix.  
gushiqiao committed
42
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
helloyongyang's avatar
helloyongyang committed
43
44
        self.weight_name = weight_name
        self.bias_name = bias_name
gushiqiao's avatar
fix.  
gushiqiao committed
45
46
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
helloyongyang's avatar
helloyongyang committed
47
48
49
50
51
52
53
54
55
56
        self.config = {}

    @abstractmethod
    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

57
58
    def set_config(self, config={}):
        self.config = config
helloyongyang's avatar
helloyongyang committed
59

gushiqiao's avatar
gushiqiao committed
60
61
    def to_cuda(self, non_blocking=False):
        self.weight = self.weight.cuda(non_blocking=non_blocking)
62
63
        if hasattr(self, "weight_scale"):
            self.weight_scale = self.weight_scale.cuda(non_blocking=non_blocking)
64
        if hasattr(self, "bias") and self.bias is not None:
gushiqiao's avatar
gushiqiao committed
65
66
            self.bias = self.bias.cuda(non_blocking=non_blocking)

67
68
69
70
71
72
73
74
75
76
77
78
79
80
    def to_cpu(self, non_blocking=False):
        if hasattr(self, "pinned_weight"):
            self.weight = self.pinned_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
            if hasattr(self, "weight_scale_name"):
                self.weight_scale = self.pinned_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
            if self.bias is not None:
                self.bias = self.pinned_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
        else:
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "weight_scale"):
                self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "bias") and self.bias is not None:
                self.bias = self.bias.to("cpu", non_blocking=non_blocking)

helloyongyang's avatar
helloyongyang committed
81

Dongz's avatar
Dongz committed
82
@MM_WEIGHT_REGISTER("Default")
helloyongyang's avatar
helloyongyang committed
83
class MMWeight(MMWeightTemplate):
gushiqiao's avatar
fix.  
gushiqiao committed
84
85
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
helloyongyang's avatar
helloyongyang committed
86
87

    def load(self, weight_dict):
88
        self.weight = weight_dict[self.weight_name].t()
Xinchi Huang's avatar
Xinchi Huang committed
89
        self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
90
        self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
Xinchi Huang's avatar
Xinchi Huang committed
91
        self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype) if self.bias is not None else None
helloyongyang's avatar
helloyongyang committed
92
93
94
95
96
97
98
99
100
101

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        if self.bias is None:
            return torch.mm(input_tensor, self.weight, out=output_tensor)
        return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)

helloyongyang's avatar
helloyongyang committed
102
103
104
105
    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        destination[self.weight_name] = self.weight.cpu().detach().clone().t().contiguous()
106
        if hasattr(self, "bias") and self.bias is not None:
helloyongyang's avatar
helloyongyang committed
107
108
109
            destination[self.bias_name] = self.bias.cpu().detach().clone()
        return destination

helloyongyang's avatar
helloyongyang committed
110

Dongz's avatar
Dongz committed
111
@MM_WEIGHT_REGISTER("Default-Force-FP32")
112
class MMWeightForceFP32(MMWeight):
gushiqiao's avatar
fix.  
gushiqiao committed
113
114
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
helloyongyang's avatar
helloyongyang committed
115
116
117
118

    def load(self, weight_dict):
        super().load(weight_dict)
        self.weight = self.weight.to(torch.float32)
119
        if hasattr(self, "bias") and self.bias is not None:
helloyongyang's avatar
helloyongyang committed
120
121
122
            self.bias = self.bias.to(torch.float32)


123
class MMWeightQuantTemplate(MMWeightTemplate):
124
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
gushiqiao's avatar
fix.  
gushiqiao committed
125
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
126
        self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale"
127
128
129
        self.load_func = None
        self.weight_need_transpose = True
        self.act_quant_func = None
130
131
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
132
        self.infer_dtype = GET_DTYPE()
133

helloyongyang's avatar
helloyongyang committed
134
135
136
    # =========================
    # weight load functions
    # =========================
137

138
139
140
141
142
    def load_from_disk(self):
        if not torch._dynamo.is_compiling():
            self.weight = self.lazy_load_file.get_tensor(self.weight_name).pin_memory()
            self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float().pin_memory()
            if self.bias_name is not None:
143
                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype).pin_memory()
144
145
146
147
        else:
            self.weight = self.lazy_load_file.get_tensor(self.weight_name)
            self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float()
            if self.bias_name is not None:
148
                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
149

helloyongyang's avatar
helloyongyang committed
150
151
        if self.weight_need_transpose:
            self.weight = self.weight.t()
152

153
154
155
156
157
    def load(self, weight_dict):
        if not self.lazy_load:
            self.load_func(weight_dict)
            if self.weight_need_transpose:
                self.weight = self.weight.t()
gushiqiao's avatar
Fix  
gushiqiao committed
158
                self.pinned_weight = self.pinned_weight.t()
159
160

    def clear(self):
gushiqiao's avatar
FIX  
gushiqiao committed
161
        attrs = ["weight", "weight_scale", "bias", "pinned_weight", "pinned_weight_scale", "pinned_bias"]
162
163
164
165
166
167
168
169
170
171
172
        for attr in attrs:
            if hasattr(self, attr):
                delattr(self, attr)
                setattr(self, attr, None)

    def _calculate_size(self):
        if self.bias is not None:
            return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size() + self.bias.numel() * self.bias.element_size()

        return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size()

173
    def load_quantized(self, weight_dict):
174
        self.weight = weight_dict[self.weight_name]
175
176
177
178
        self.weight_scale = weight_dict[self.weight_scale_name].float()

        self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
        self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
179
180

    def load_fp8_perchannel_sym(self, weight_dict):
181
        if self.config.get("weight_auto_quant", False):
182
            self.weight = weight_dict[self.weight_name].to(torch.float32)
183
184
185
186
            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.float8_e4m3fn)
            self.weight_scale = self.weight_scale.to(torch.float32)
187
188
            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
            self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
189
190
        else:
            self.load_quantized(weight_dict)
191
192
193
194
195
196

        if self.bias_name is not None:
            self.bias = weight_dict[self.bias_name]
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
        else:
            self.bias = None
197
198

    def load_int8_perchannel_sym(self, weight_dict):
199
        if self.config.get("weight_auto_quant", False):
200
            self.weight = weight_dict[self.weight_name].to(torch.float32)
201
202
203
204
            w_quantizer = IntegerQuantizer(8, True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.int8)
            self.weight_scale = self.weight_scale.to(torch.float32)
205
206
            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
            self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
207
208
        else:
            self.load_quantized(weight_dict)
209
210
211
212
213
214

        if self.bias_name is not None:
            self.bias = weight_dict[self.bias_name]
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
        else:
            self.bias = None
215
216

    def load_fp8_perblock128_sym(self, weight_dict):
217
        if self.config.get("weight_auto_quant", False):
218
            self.weight = weight_dict[self.weight_name]
219
            self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight)
220
221
            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
            self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
222
223
        else:
            self.load_quantized(weight_dict)
224
225
226
227
228
229

        if self.bias_name is not None:
            self.bias = weight_dict[self.bias_name]
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
        else:
            self.bias = None
230
231
232
233

    def per_block_cast_to_fp8(self, x):
        assert x.dim() == 2
        m, n = x.shape
234
235
236
237
238
        x_padded = torch.zeros(
            (deep_gemm.ceil_div(m, 128) * 128, deep_gemm.ceil_div(n, 128) * 128),
            dtype=x.dtype,
            device=x.device,
        )
239
240
241
242
243
244
        x_padded[:m, :n] = x
        x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
        x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
        x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))

helloyongyang's avatar
helloyongyang committed
245
246
247
    # =========================
    # act quant kernels
    # =========================
gushiqiao's avatar
gushiqiao committed
248
249
250
    def act_quant_int8_perchannel_sym_torchao(self, x):
        input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
        return input_tensor_quant, input_tensor_scale
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277

    def act_quant_fp8_perchannel_sym_vllm(self, x):
        input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannel_sym_sgl(self, x):
        m, k = x.shape
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
        input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device="cuda", requires_grad=False)
        sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale)
        return input_tensor_quant, input_tensor_scale

    def act_quant_int8_perchannel_sym_vllm(self, x):
        input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannelgroup128_sym_deepgemm(self, x):
        assert x.dim() == 2 and x.size(1) % 128 == 0
        m, n = x.shape
        x_view = x.view(m, -1, 128)
        x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
        return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)

    def act_quant_fp8_perchannelgroup128_sym_sgl(self, x):
        m, k = x.shape
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
        input_tensor_scale = torch.empty((m, k // 128), dtype=torch.float32, device="cuda", requires_grad=False)
278
279
280
281
282
283
284
285
286
        sgl_kernel.sgl_per_token_group_quant_fp8(
            x,
            input_tensor_quant,
            input_tensor_scale,
            group_size=128,
            eps=1e-10,
            fp8_min=-448.0,
            fp8_max=448.0,
        )
287
288
        return input_tensor_quant, input_tensor_scale

289
290
291
    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
helloyongyang's avatar
helloyongyang committed
292
293
294
295
        if self.weight_need_transpose:
            destination[self.weight_name] = self.weight.cpu().detach().clone().t().contiguous()
        else:
            destination[self.weight_name] = self.weight.cpu().detach().clone().contiguous()
296
        if hasattr(self, "bias") and self.bias is not None:
297
298
            destination[self.bias_name] = self.bias.cpu().detach().clone()
        if hasattr(self, "weight_scale"):
299
            destination[self.weight_name.removesuffix(".weight") + ".weight_scale"] = self.weight_scale.cpu().detach().clone()
300
301
        return destination

302

Dongz's avatar
Dongz committed
303
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm")
304
class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
Dongz's avatar
Dongz committed
305
    """
helloyongyang's avatar
helloyongyang committed
306
307
308
309
310
311
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: vllm
Dongz's avatar
Dongz committed
312
313
    """

314
315
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
316
317
318
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
helloyongyang's avatar
helloyongyang committed
319
320
321
322
323
324

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
325
326

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
327
328
329
330
331
332
333
334
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.bias,
        )
helloyongyang's avatar
helloyongyang committed
335
336
337
        return output_tensor


Dongz's avatar
Dongz committed
338
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm")
339
class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
Dongz's avatar
Dongz committed
340
    """
helloyongyang's avatar
helloyongyang committed
341
342
343
344
345
346
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: vllm
Dongz's avatar
Dongz committed
347
348
    """

349
350
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
351
352
353
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
helloyongyang's avatar
helloyongyang committed
354
355
356
357
358
359

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
360
361

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
362
363
364
365
366
367
368
369
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.bias,
        )
helloyongyang's avatar
helloyongyang committed
370
371
372
        return output_tensor


373
374
375
376
377
378
379
380
381
382
383
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Q8F
    """

384
385
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
386
387
388
389
390
391
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
392
393
394
395
396
397
        output_tensor = Q8F.linear.fp8_linear(
            input_tensor_quant,
            self.weight,
            self.bias.float(),
            input_tensor_scale,
            self.weight_scale,
398
            out_dtype=self.infer_dtype,
399
        )
400
401
402
        return output_tensor.squeeze(0)


Dongz's avatar
Dongz committed
403
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F")
404
class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
Dongz's avatar
Dongz committed
405
    """
406
407
408
409
410
411
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Q8F
Dongz's avatar
Dongz committed
412
413
    """

414
415
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
416
417
418
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
419

420
421
    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
422
423
424
425
426
427
428
        output_tensor = Q8F.linear.q8_linear(
            input_tensor_quant,
            self.weight,
            self.bias.float(),
            input_tensor_scale,
            self.weight_scale,
            fuse_gelu=False,
429
            out_dtype=self.infer_dtype,
430
        )
431
432
433
        return output_tensor.squeeze(0)


434
435
@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemm(MMWeightQuantTemplate):
Dongz's avatar
Dongz committed
436
    """
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm

    Quant MM:
        Weight: fp8 perblock 128x128 sym
        Act: fp8 perchannel-pergroup group=128 dynamic sym
        Kernel: Deepgemm

    Reference: https://github.com/deepseek-ai/DeepGEMM

    Example:
        Act(1024, 2048) x Weight(2048, 4096) = Out(1024, 4096)

        Act : torch.Size([1024, 2048]), torch.float8_e4m3fn
        Act Scale: torch.Size([1024, 16]), torch.float32
        Weight : torch.Size([4096, 2048]), torch.float8_e4m3fn
        Weight Scale: torch.Size([32, 16]), torch.float32
453
        Out : torch.Size([1024, 4096]), self.infer_dtype
454
455
    """

456
457
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
458
459
460
461
462
463
464
465
466
467
468
        self.load_func = self.load_fp8_perblock128_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_deepgemm

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[0])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
469
470
471
472
473
474
        deep_gemm.gemm_fp8_fp8_bf16_nt(
            (input_tensor_quant, input_tensor_scale),
            (self.weight, self.weight_scale),
            output_tensor,
        )
        if hasattr(self, "bias") and self.bias is not None:
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
            output_tensor.add_(self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl

    Quant MM:
        Weight: fp8 perblock 128x128 sym
        Act: fp8 pertoken-pergroup group=128 dynamic sym
        Kernel: quant-mm using Deepgemm, act dynamic quant using Sgl-kernel
    """

490
491
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
492
493
494
495
496
497
498
499
500
501
502
        self.load_func = self.load_fp8_perblock128_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_sgl

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[0])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
503
504
505
506
507
508
        deep_gemm.gemm_fp8_fp8_bf16_nt(
            (input_tensor_quant, input_tensor_scale),
            (self.weight, self.weight_scale),
            output_tensor,
        )
        if hasattr(self, "bias") and self.bias is not None:
509
510
511
512
513
514
515
516
            output_tensor.add_(self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl")
class MMWeightWfp8channelAfp8channeldynamicVllmActSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl
517
518
519
520

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
521
        Kernel: quant-mm using vllm, act dynamic quant using Sgl-kernel
Dongz's avatar
Dongz committed
522
523
    """

524
525
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
526
527
528
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl
529

530
531
532
533
534
535
536
    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
537
538
539
540
541
542
543
544
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.bias,
        )
545
546
547
        return output_tensor


helloyongyang's avatar
helloyongyang committed
548
549
550
551
552
553
554
555
556
557
558
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm")
class MMWeightWfp8channelAfp8channeldynamicSglActVllm(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

559
560
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
helloyongyang's avatar
helloyongyang committed
561
562
563
564
565
566
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
567
568
569
570
571
        output_tensor = sgl_kernel.fp8_scaled_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
572
            self.infer_dtype,
573
574
            bias=self.bias,
        )
helloyongyang's avatar
helloyongyang committed
575
576
577
        return output_tensor


578
579
580
581
582
583
584
585
586
587
588
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl")
class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Sgl-kernel
    """

589
590
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
591
592
593
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl
594
595

    def apply(self, input_tensor):
596
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
597
598
599
600
601
        output_tensor = sgl_kernel.fp8_scaled_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
602
            self.infer_dtype,
603
604
            bias=self.bias,
        )
605
606
607
608
        return output_tensor


@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm")
helloyongyang's avatar
helloyongyang committed
609
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
610
611
612
613
614
615
616
617
618
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

619
620
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
621
622
623
624
625
626
627
628
629
630
631
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
632
633
634
635
636
        output_tensor = sgl_kernel.int8_scaled_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
637
            self.infer_dtype,
638
639
            self.bias,
        )
640
        return output_tensor
641
642


gushiqiao's avatar
gushiqiao committed
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao")
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Torchao
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao

    def apply(self, input_tensor):
        input_tensor = input_tensor
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
663
        output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=self.infer_dtype)
gushiqiao's avatar
gushiqiao committed
664
665
666
667
668
669
        if self.bias is not None:
            output_tensor = output_tensor + self.bias

        return output_tensor


670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
class MMWeightGGUFTemplate(MMWeightQuantTemplate):
    TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16)

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

    def dequantize_func(self):
        # TODO: implement dequantize_func
        pass


@MM_WEIGHT_REGISTER("W-gguf-Q4_K")
class MMWeightGGUFQ4K(MMWeightGGUFTemplate):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)


Dongz's avatar
Dongz committed
687
if __name__ == "__main__":
helloyongyang's avatar
helloyongyang committed
688
    weight_dict = {
helloyongyang's avatar
helloyongyang committed
689
        "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
Dongz's avatar
Dongz committed
690
691
        "xx.bias": torch.randn(8192).to(torch.bfloat16),
        "xx.weight_scale": torch.randn(8192, 1).to(torch.float32),
helloyongyang's avatar
helloyongyang committed
692
693
    }

Dongz's avatar
Dongz committed
694
695
    mm_weight = MM_WEIGHT_REGISTER["W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
    mm_weight.set_config({"weight_auto_quant": False})
helloyongyang's avatar
helloyongyang committed
696
697
698
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
root's avatar
root committed
699
    logger.info(output_tensor.shape)
helloyongyang's avatar
helloyongyang committed
700
701

    weight_dict = {
Dongz's avatar
Dongz committed
702
703
        "xx.weight": torch.randn(8192, 4096),
        "xx.bias": torch.randn(8192).to(torch.bfloat16),
helloyongyang's avatar
helloyongyang committed
704
705
    }

Dongz's avatar
Dongz committed
706
707
    mm_weight = MM_WEIGHT_REGISTER["W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
    mm_weight.set_config({"weight_auto_quant": True})
helloyongyang's avatar
helloyongyang committed
708
709
710
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
root's avatar
root committed
711
    logger.info(output_tensor.shape)
helloyongyang's avatar
helloyongyang committed
712
713

    weight_dict = {
Dongz's avatar
Dongz committed
714
715
        "xx.weight": torch.randn(8192, 4096),
        "xx.bias": torch.randn(8192).to(torch.bfloat16),
helloyongyang's avatar
helloyongyang committed
716
717
    }

Dongz's avatar
Dongz committed
718
719
    mm_weight = MM_WEIGHT_REGISTER["W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
    mm_weight.set_config({"weight_auto_quant": True})
helloyongyang's avatar
helloyongyang committed
720
721
722
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
root's avatar
root committed
723
    logger.info(output_tensor.shape)