mm_weight.py 26.7 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
import torch
from abc import ABCMeta, abstractmethod
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
5
from lightx2v.utils.envs import *
root's avatar
root committed
6
from loguru import logger
Dongz's avatar
Dongz committed
7

gushiqiao's avatar
gushiqiao committed
8
9
10
11
12
13
14
15
16
17
try:
    from vllm import _custom_ops as ops
except ImportError:
    ops = None

try:
    import sgl_kernel
except ImportError:
    sgl_kernel = None

18
19
20
21
try:
    import q8_kernels.functional as Q8F
except ImportError:
    Q8F = None
helloyongyang's avatar
helloyongyang committed
22

23
24
25
26
27
try:
    import deep_gemm
except ImportError:
    deep_gemm = None

helloyongyang's avatar
helloyongyang committed
28
29

class MMWeightTemplate(metaclass=ABCMeta):
    """Base class for matmul weight wrappers.

    Holds a weight tensor (and an optional bias), plus an optional set of
    pre-allocated pinned host buffers, and manages moving everything between
    CPU and CUDA. Subclasses implement ``load`` and ``apply``.
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        self.weight_name = weight_name
        # bias_name may be None, meaning the layer has no bias.
        self.bias_name = bias_name
        self.lazy_load = lazy_load
        # File handle used to fetch tensors on demand when lazy_load is True.
        self.lazy_load_file = lazy_load_file
        self.config = {}

    @abstractmethod
    def load(self, weight_dict):
        """Populate self.weight / self.bias (and any scales) from weight_dict."""
        pass

    @abstractmethod
    def apply(self, input_tensor):
        """Compute the matmul (plus bias, if any) for input_tensor."""
        pass

    def set_config(self, config=None):
        # Fix: the original used a mutable default argument (config={}),
        # so every no-arg caller shared one dict instance. Use None instead.
        self.config = config if config is not None else {}

    def to_cuda(self, non_blocking=False):
        self.weight = self.weight.cuda(non_blocking=non_blocking)
        if hasattr(self, "weight_scale"):
            self.weight_scale = self.weight_scale.cuda(non_blocking=non_blocking)
        if hasattr(self, "bias") and self.bias is not None:
            self.bias = self.bias.cuda(non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
        if hasattr(self, "pinned_weight"):
            # Stage the transfer through pre-allocated pinned host buffers so the
            # device-to-host copy can overlap with compute when non_blocking=True.
            self.weight = self.pinned_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
            # NOTE(review): this checks weight_scale_name (set by the quant
            # subclass) rather than weight_scale itself — confirm intent.
            if hasattr(self, "weight_scale_name"):
                self.weight_scale = self.pinned_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
            if self.bias is not None:
                self.bias = self.pinned_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
        else:
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "weight_scale"):
                self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "bias") and self.bias is not None:
                self.bias = self.bias.to("cpu", non_blocking=non_blocking)

helloyongyang's avatar
helloyongyang committed
69

Dongz's avatar
Dongz committed
70
@MM_WEIGHT_REGISTER("Default")
class MMWeight(MMWeightTemplate):
    """Plain (unquantized) matmul weight: out = x @ W^T (+ bias)."""

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

    def load(self, weight_dict):
        # Checkpoints store weights as (out, in); transpose once at load time
        # so apply() can run a straight mm/addmm.
        self.weight = weight_dict[self.weight_name].t()
        self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
        if self.bias_name is not None:
            self.bias = weight_dict[self.bias_name]
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
        else:
            self.bias = None
            self.pinned_bias = None

    def apply(self, input_tensor):
        out = torch.empty(
            (input_tensor.shape[0], self.weight.shape[1]),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )
        if self.bias is None:
            return torch.mm(input_tensor, self.weight, out=out)
        return torch.addmm(self.bias, input_tensor, self.weight, out=out)

    def state_dict(self, destination=None):
        destination = {} if destination is None else destination
        # Undo the load-time transpose so the exported layout matches the checkpoint.
        destination[self.weight_name] = self.weight.cpu().detach().clone().t().contiguous()
        if getattr(self, "bias", None) is not None:
            destination[self.bias_name] = self.bias.cpu().detach().clone()
        return destination

helloyongyang's avatar
helloyongyang committed
98

Dongz's avatar
Dongz committed
99
@MM_WEIGHT_REGISTER("Default-Force-FP32")
class MMWeightForceFP32(MMWeight):
    """MMWeight variant that promotes weight and bias to float32 after loading."""

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

    def load(self, weight_dict):
        super().load(weight_dict)
        self.weight = self.weight.to(torch.float32)
        bias = getattr(self, "bias", None)
        if bias is not None:
            self.bias = bias.to(torch.float32)


111
class MMWeightQuantTemplate(MMWeightTemplate):
    """Shared machinery for quantized matmul weights.

    Concrete subclasses configure three hooks in __init__:
      * load_func             -- checkpoint loading / on-the-fly weight quantization
      * weight_need_transpose -- whether the stored (out, in) weight must be
                                 transposed for the GEMM kernel used by apply()
      * act_quant_func        -- dynamic activation-quantization kernel
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        # The scale tensor lives next to the weight: "xxx.weight" -> "xxx.weight_scale".
        self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale"
        self.load_func = None
        self.weight_need_transpose = True
        self.act_quant_func = None
        # NOTE: redundant — super().__init__ already assigned these (harmless).
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file

    # =========================
    # weight load functions
    # =========================
    def load_from_disk(self):
        # Lazy path: fetch weight/scale/bias straight from the safetensors handle.
        if not torch._dynamo.is_compiling():
            # Eager mode: pin host memory so later H2D copies can be asynchronous.
            self.weight = self.lazy_load_file.get_tensor(self.weight_name).pin_memory()
            self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float().pin_memory()
            if self.bias_name is not None:
                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16).pin_memory()
        else:
            # Under torch.compile tracing pin_memory() is skipped.
            self.weight = self.lazy_load_file.get_tensor(self.weight_name)
            self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float()
            if self.bias_name is not None:
                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16)

        if self.weight_need_transpose:
            self.weight = self.weight.t()

    def load(self, weight_dict):
        # In lazy mode nothing happens here; load_from_disk() is used instead.
        if not self.lazy_load:
            self.load_func(weight_dict)
            if self.weight_need_transpose:
                self.weight = self.weight.t()
                # Keep the pinned staging buffer's layout in sync with the weight.
                self.pinned_weight = self.pinned_weight.t()

    def clear(self):
        # Drop all tensor references (weights, scales, bias, pinned buffers).
        # NOTE: each attribute is re-created as None after deletion, so
        # hasattr() checks still succeed afterwards; reload before reuse.
        attrs = ["weight", "weight_scale", "bias", "pinned_weight", "pinned_weight_scale", "pinned_bias"]
        for attr in attrs:
            if hasattr(self, attr):
                delattr(self, attr)
                setattr(self, attr, None)

    def _calculate_size(self):
        # Total byte size of weight + scale (+ bias when present).
        if self.bias is not None:
            return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size() + self.bias.numel() * self.bias.element_size()

        return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size()

    def load_quantized(self, weight_dict):
        # Load an already-quantized weight and its scale from the checkpoint.
        self.weight = weight_dict[self.weight_name]
        self.weight_scale = weight_dict[self.weight_scale_name].float()

        self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
        self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)

    def load_fp8_perchannel_sym(self, weight_dict):
        # fp8(e4m3) per-channel symmetric weight; quantize on the fly if requested.
        if self.config.get("weight_auto_quant", False):
            self.weight = weight_dict[self.weight_name].to(torch.float32)
            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.float8_e4m3fn)
            self.weight_scale = self.weight_scale.to(torch.float32)
            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
            self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
        else:
            self.load_quantized(weight_dict)

        if self.bias_name is not None:
            self.bias = weight_dict[self.bias_name]
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
        else:
            self.bias = None

    def load_int8_perchannel_sym(self, weight_dict):
        # int8 per-channel symmetric weight; quantize on the fly if requested.
        if self.config.get("weight_auto_quant", False):
            self.weight = weight_dict[self.weight_name].to(torch.float32)
            w_quantizer = IntegerQuantizer(8, True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.int8)
            self.weight_scale = self.weight_scale.to(torch.float32)
            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
            self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
        else:
            self.load_quantized(weight_dict)

        if self.bias_name is not None:
            self.bias = weight_dict[self.bias_name]
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
        else:
            self.bias = None

    def load_fp8_perblock128_sym(self, weight_dict):
        # fp8(e4m3) 128x128 per-block symmetric weight (DeepGEMM layout).
        if self.config.get("weight_auto_quant", False):
            self.weight = weight_dict[self.weight_name]
            self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight)
            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
            self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
        else:
            self.load_quantized(weight_dict)

        if self.bias_name is not None:
            self.bias = weight_dict[self.bias_name]
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
        else:
            self.bias = None

    def per_block_cast_to_fp8(self, x):
        """Quantize a 2-D tensor to fp8 e4m3 with one scale per 128x128 block.

        Returns (quantized tensor trimmed back to x's shape, per-block scales).
        448.0 is the maximum representable magnitude of float8_e4m3fn.
        """
        assert x.dim() == 2
        m, n = x.shape
        # Pad both dims up to multiples of 128 so the 128x128 block view is exact.
        x_padded = torch.zeros(
            (deep_gemm.ceil_div(m, 128) * 128, deep_gemm.ceil_div(n, 128) * 128),
            dtype=x.dtype,
            device=x.device,
        )
        x_padded[:m, :n] = x
        x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
        # clamp(1e-4) avoids division by ~0 for all-zero blocks.
        x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
        x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))

    # =========================
    # act quant kernels
    # =========================
    def act_quant_fp8_perchannel_sym_vllm(self, x):
        # vllm dynamic per-token fp8 quantization.
        input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannel_sym_sgl(self, x):
        # sgl-kernel dynamic per-token fp8 quantization (one scale per row).
        m, k = x.shape
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
        input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device="cuda", requires_grad=False)
        sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale)
        return input_tensor_quant, input_tensor_scale

    def act_quant_int8_perchannel_sym_vllm(self, x):
        # vllm dynamic symmetric int8 quantization (zero-point unused, hence "_").
        input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannelgroup128_sym_deepgemm(self, x):
        # Pure-torch per-token, per-128-group fp8 quantization (DeepGEMM layout).
        assert x.dim() == 2 and x.size(1) % 128 == 0
        m, n = x.shape
        x_view = x.view(m, -1, 128)
        x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
        return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)

    def act_quant_fp8_perchannelgroup128_sym_sgl(self, x):
        # sgl-kernel per-token, per-128-group fp8 quantization.
        m, k = x.shape
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
        input_tensor_scale = torch.empty((m, k // 128), dtype=torch.float32, device="cuda", requires_grad=False)
        sgl_kernel.sgl_per_token_group_quant_fp8(
            x,
            input_tensor_quant,
            input_tensor_scale,
            group_size=128,
            eps=1e-10,
            fp8_min=-448.0,
            fp8_max=448.0,
        )
        return input_tensor_quant, input_tensor_scale

    def state_dict(self, destination=None):
        # Export weight/bias/scale back in checkpoint layout (undoing transpose).
        if destination is None:
            destination = {}
        if self.weight_need_transpose:
            destination[self.weight_name] = self.weight.cpu().detach().clone().t().contiguous()
        else:
            destination[self.weight_name] = self.weight.cpu().detach().clone().contiguous()
        if hasattr(self, "bias") and self.bias is not None:
            destination[self.bias_name] = self.bias.cpu().detach().clone()
        if hasattr(self, "weight_scale"):
            destination[self.weight_name.removesuffix(".weight") + ".weight_scale"] = self.weight_scale.cpu().detach().clone()
        return destination

286

Dongz's avatar
Dongz committed
287
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm")
class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        # Quantize the activation on the fly, then run the fused scaled GEMM
        # (cutlass_scaled_mm applies both scales and the optional bias).
        x_q, x_scale = self.act_quant_func(input_tensor)
        out = torch.empty(
            (input_tensor.shape[0], self.weight.shape[1]),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )
        torch.ops._C.cutlass_scaled_mm(out, x_q, self.weight, x_scale, self.weight_scale, self.bias)
        return out


Dongz's avatar
Dongz committed
322
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm")
class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        # Dynamic per-token int8 quantization followed by the fused scaled GEMM.
        x_q, x_scale = self.act_quant_func(input_tensor)
        out = torch.empty(
            (input_tensor.shape[0], self.weight.shape[1]),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )
        torch.ops._C.cutlass_scaled_mm(out, x_q, self.weight, x_scale, self.weight_scale, self.bias)
        return out


357
358
359
360
361
362
363
364
365
366
367
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        # Q8F consumes the (out, in) layout directly — no transpose at load time.
        self.weight_need_transpose = False
        # NOTE(review): activation quant uses the vllm kernel here even though
        # the GEMM is Q8F — presumably the formats are compatible; confirm.
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        # Fix: guard against bias-free layers — the original called
        # self.bias.float() unconditionally and crashed when bias_name is None
        # (every sibling kernel class tolerates a None bias).
        bias = self.bias.float() if self.bias is not None else None
        output_tensor = Q8F.linear.fp8_linear(
            input_tensor_quant,
            self.weight,
            bias,
            input_tensor_scale,
            self.weight_scale,
            out_dtype=torch.bfloat16,
        )
        return output_tensor.squeeze(0)


Dongz's avatar
Dongz committed
387
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F")
class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        # Q8F consumes the (out, in) layout directly — no transpose at load time.
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        # Fix: guard against bias-free layers — the original called
        # self.bias.float() unconditionally and crashed when bias_name is None
        # (every sibling kernel class tolerates a None bias).
        bias = self.bias.float() if self.bias is not None else None
        output_tensor = Q8F.linear.q8_linear(
            input_tensor_quant,
            self.weight,
            bias,
            input_tensor_scale,
            self.weight_scale,
            fuse_gelu=False,
            out_dtype=torch.bfloat16,
        )
        return output_tensor.squeeze(0)


418
419
@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemm(MMWeightQuantTemplate):
    """
    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm

    Quant MM:
        Weight: fp8 perblock 128x128 sym
        Act: fp8 perchannel-pergroup group=128 dynamic sym (pure-torch quant)
        Kernel: Deepgemm

    Reference: https://github.com/deepseek-ai/DeepGEMM

    Shape example for Act(1024, 2048) x Weight(2048, 4096) = Out(1024, 4096):
        Act          torch.Size([1024, 2048])  torch.float8_e4m3fn
        Act Scale    torch.Size([1024, 16])    torch.float32
        Weight       torch.Size([4096, 2048])  torch.float8_e4m3fn
        Weight Scale torch.Size([32, 16])      torch.float32
        Out          torch.Size([1024, 4096])  torch.bfloat16
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perblock128_sym
        # DeepGEMM's NT layout takes the weight as stored (out, in).
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_deepgemm

    def apply(self, input_tensor):
        x_q, x_s = self.act_quant_func(input_tensor)
        out = torch.empty(
            (input_tensor.shape[0], self.weight.shape[0]),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )
        deep_gemm.gemm_fp8_fp8_bf16_nt((x_q, x_s), (self.weight, self.weight_scale), out)
        # DeepGEMM does not fuse the bias; add it afterwards in place.
        if getattr(self, "bias", None) is not None:
            out.add_(self.bias)
        return out


@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl

    Quant MM:
        Weight: fp8 perblock 128x128 sym
        Act: fp8 pertoken-pergroup group=128 dynamic sym
        Kernel: quant-mm using Deepgemm, act dynamic quant using Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perblock128_sym
        # DeepGEMM's NT layout takes the weight as stored (out, in).
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_sgl

    def apply(self, input_tensor):
        x_q, x_s = self.act_quant_func(input_tensor)
        out = torch.empty(
            (input_tensor.shape[0], self.weight.shape[0]),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )
        deep_gemm.gemm_fp8_fp8_bf16_nt((x_q, x_s), (self.weight, self.weight_scale), out)
        # DeepGEMM does not fuse the bias; add it afterwards in place.
        if getattr(self, "bias", None) is not None:
            out.add_(self.bias)
        return out


@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl")
class MMWeightWfp8channelAfp8channeldynamicVllmActSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: quant-mm using vllm, act dynamic quant using Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl

    def apply(self, input_tensor):
        # Per-token fp8 quant via sgl-kernel, GEMM via vllm's cutlass op
        # (which also applies both scales and the optional bias).
        x_q, x_scale = self.act_quant_func(input_tensor)
        out = torch.empty(
            (input_tensor.shape[0], self.weight.shape[1]),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )
        torch.ops._C.cutlass_scaled_mm(out, x_q, self.weight, x_scale, self.weight_scale, self.bias)
        return out


helloyongyang's avatar
helloyongyang committed
532
533
534
535
536
537
538
539
540
541
542
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm")
class MMWeightWfp8channelAfp8channeldynamicSglActVllm(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        # Per-token fp8 quant via vllm, fused scaled GEMM via sgl-kernel.
        x_q, x_scale = self.act_quant_func(input_tensor)
        return sgl_kernel.fp8_scaled_mm(
            x_q,
            self.weight,
            x_scale,
            self.weight_scale,
            torch.bfloat16,
            bias=self.bias,
        )


562
563
564
565
566
567
568
569
570
571
572
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl")
class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl

    def apply(self, input_tensor):
        # Both activation quant and the fused scaled GEMM run on sgl-kernel.
        x_q, x_scale = self.act_quant_func(input_tensor)
        return sgl_kernel.fp8_scaled_mm(
            x_q,
            self.weight,
            x_scale,
            self.weight_scale,
            torch.bfloat16,
            bias=self.bias,
        )


@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm")
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        # Dynamic per-token int8 quant via vllm, fused scaled GEMM via sgl-kernel.
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        # Fix: the original pre-allocated an output buffer that
        # sgl_kernel.int8_scaled_mm's return value immediately replaced;
        # that dead allocation is removed here.
        output_tensor = sgl_kernel.int8_scaled_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            torch.bfloat16,
            self.bias,
        )
        return output_tensor
625
626


Dongz's avatar
Dongz committed
627
if __name__ == "__main__":
    # Manual CUDA smoke tests for the registered quantized-MM implementations.
    # NOTE(review): mm_weight is never moved to CUDA while the input is —
    # presumably the kernels require device-resident weights; confirm this
    # script runs as-is.

    # Case 1: pre-quantized fp8 checkpoint (weight_auto_quant=False), vllm kernel.
    weight_dict = {
        "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
        "xx.bias": torch.randn(8192).to(torch.bfloat16),
        "xx.weight_scale": torch.randn(8192, 1).to(torch.float32),
    }

    mm_weight = MM_WEIGHT_REGISTER["W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
    mm_weight.set_config({"weight_auto_quant": False})
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    logger.info(output_tensor.shape)

    # Case 2: on-the-fly fp8 quantization of a float weight (weight_auto_quant=True).
    weight_dict = {
        "xx.weight": torch.randn(8192, 4096),
        "xx.bias": torch.randn(8192).to(torch.bfloat16),
    }

    mm_weight = MM_WEIGHT_REGISTER["W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
    mm_weight.set_config({"weight_auto_quant": True})
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    logger.info(output_tensor.shape)

    # Case 3: on-the-fly int8 quantization (weight_auto_quant=True), vllm kernel.
    weight_dict = {
        "xx.weight": torch.randn(8192, 4096),
        "xx.bias": torch.randn(8192).to(torch.bfloat16),
    }

    mm_weight = MM_WEIGHT_REGISTER["W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
    mm_weight.set_config({"weight_auto_quant": True})
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    logger.info(output_tensor.shape)