mm_weight.py 26.5 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
import torch
from abc import ABCMeta, abstractmethod
from vllm import _custom_ops as ops
4

gushiqiao's avatar
fix.  
gushiqiao committed
5
import sgl_kernel
helloyongyang's avatar
helloyongyang committed
6
7
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
8
from lightx2v.utils.envs import *
root's avatar
root committed
9
from loguru import logger
Dongz's avatar
Dongz committed
10

11
12
13
14
try:
    import q8_kernels.functional as Q8F
except ImportError:
    Q8F = None
helloyongyang's avatar
helloyongyang committed
15

16
17
18
19
20
try:
    import deep_gemm
except ImportError:
    deep_gemm = None

helloyongyang's avatar
helloyongyang committed
21
22

class MMWeightTemplate(metaclass=ABCMeta):
    """Abstract base for matmul weight holders.

    Subclasses implement ``load`` (populate ``self.weight`` / ``self.bias`` and
    related tensors from a state dict) and ``apply`` (run the matmul on an
    input). This base class stores the tensor names and lazy-load handle, keeps
    a per-instance config dict, and moves tensors between host and device.
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        self.weight_name = weight_name
        self.bias_name = bias_name  # may be None for bias-free layers
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.config = {}

    @abstractmethod
    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        """Attach ``config`` to this weight; ``None`` installs a fresh empty dict.

        A new dict is created per call (rather than a shared ``{}`` default) so
        that mutating one instance's config can never leak into another's.
        """
        self.config = {} if config is None else config

    def to_cuda(self, non_blocking=False):
        """Move weight, weight_scale (if present) and bias (if not None) to the current CUDA device."""
        self.weight = self.weight.cuda(non_blocking=non_blocking)
        if hasattr(self, "weight_scale"):
            self.weight_scale = self.weight_scale.cuda(non_blocking=non_blocking)
        if getattr(self, "bias", None) is not None:
            self.bias = self.bias.cuda(non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
        """Move tensors back to host memory.

        When the subclass allocated pinned staging buffers (``pinned_weight``
        etc.), the data is copied into them and the pinned tensors become the
        live attributes. Otherwise a plain ``.to("cpu")`` is used.
        """
        if hasattr(self, "pinned_weight"):
            # copy_ writes into the preallocated pinned buffer and returns it.
            self.weight = self.pinned_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
            if hasattr(self, "weight_scale_name"):
                self.weight_scale = self.pinned_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
            if getattr(self, "bias", None) is not None:
                self.bias = self.pinned_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
        else:
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "weight_scale"):
                self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
            if getattr(self, "bias", None) is not None:
                self.bias = self.bias.to("cpu", non_blocking=non_blocking)

helloyongyang's avatar
helloyongyang committed
62

Dongz's avatar
Dongz committed
63
@MM_WEIGHT_REGISTER("Default")
class MMWeight(MMWeightTemplate):
    """Unquantized matmul weight.

    The checkpoint weight is stored transposed (``.t()``) so ``apply`` can call
    ``torch.mm(input, weight)`` directly. An optional bias is added via
    ``torch.addmm``. Pinned host buffers of matching shape/dtype are allocated
    at load time for ``to_cpu`` offloading.
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

    def load(self, weight_dict):
        transposed = weight_dict[self.weight_name].t()
        self.weight = transposed
        self.pinned_weight = torch.empty(transposed.shape, pin_memory=True, dtype=transposed.dtype)

        bias = None if self.bias_name is None else weight_dict[self.bias_name]
        self.bias = bias
        if bias is None:
            self.pinned_bias = None
        else:
            self.pinned_bias = torch.empty(bias.shape, pin_memory=True, dtype=bias.dtype)

    def apply(self, input_tensor):
        out = torch.empty(
            (input_tensor.shape[0], self.weight.shape[1]),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )
        if self.bias is not None:
            return torch.addmm(self.bias, input_tensor, self.weight, out=out)
        return torch.mm(input_tensor, self.weight, out=out)

    def state_dict(self, destination=None):
        """Export tensors to host memory, undoing the load-time transpose of the weight."""
        result = {} if destination is None else destination
        result[self.weight_name] = self.weight.cpu().detach().clone().t().contiguous()
        bias = getattr(self, "bias", None)
        if bias is not None:
            result[self.bias_name] = bias.cpu().detach().clone()
        return result

helloyongyang's avatar
helloyongyang committed
91

Dongz's avatar
Dongz committed
92
@MM_WEIGHT_REGISTER("Default-Force-FP32")
class MMWeightForceFP32(MMWeight):
    """Unquantized matmul weight whose weight and bias are upcast to float32 at load time."""

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

    def load(self, weight_dict):
        super().load(weight_dict)
        self.weight = self.weight.to(torch.float32)
        # The parent allocated pinned staging buffers in the checkpoint dtype.
        # to_cpu() does `pinned.copy_(tensor)`, which casts into the buffer's
        # dtype, so a stale buffer would silently downcast the fp32 tensors on
        # every offload. Reallocate when the dtype actually changed.
        if self.pinned_weight.dtype != self.weight.dtype:
            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
        if getattr(self, "bias", None) is not None:
            self.bias = self.bias.to(torch.float32)
            if self.pinned_bias is None or self.pinned_bias.dtype != self.bias.dtype:
                self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)


104
class MMWeightQuantTemplate(MMWeightTemplate):
    """Base class for quantized matmul weights.

    Subclasses configure three strategies in ``__init__``:

    * ``load_func`` - one of the ``load_*`` methods below, which fills
      ``self.weight``, ``self.weight_scale`` and ``self.bias`` (plus pinned
      host staging buffers) from a state dict;
    * ``weight_need_transpose`` - whether the weight is transposed after
      loading (the layout the chosen kernel expects);
    * ``act_quant_func`` - one of the ``act_quant_*`` methods, which quantizes
      the activation and returns ``(quantized, scale)``.
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        # "<prefix>.weight" -> "<prefix>.weight_scale"
        self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale"
        self.load_func = None
        self.weight_need_transpose = True
        self.act_quant_func = None

    # =========================
    # weight load functions
    # =========================

    def load_from_disk(self):
        """Fetch weight, weight_scale and bias from ``self.lazy_load_file``.

        The scale is upcast to float32 and the bias cast to bfloat16. Tensors
        are pinned only when not being traced by ``torch.compile``. The weight
        is transposed afterwards if ``weight_need_transpose`` is set.
        ``self.bias`` is set to None when the layer has no bias.
        """
        # presumably a safetensors-style handle exposing get_tensor(name) - TODO confirm
        reader = self.lazy_load_file
        weight = reader.get_tensor(self.weight_name)
        weight_scale = reader.get_tensor(self.weight_scale_name).float()
        bias = reader.get_tensor(self.bias_name).to(torch.bfloat16) if self.bias_name is not None else None

        if not torch._dynamo.is_compiling():
            weight = weight.pin_memory()
            weight_scale = weight_scale.pin_memory()
            if bias is not None:
                bias = bias.pin_memory()

        self.weight = weight.t() if self.weight_need_transpose else weight
        self.weight_scale = weight_scale
        self.bias = bias

    def load(self, weight_dict):
        """Eagerly load via ``load_func`` unless lazy loading is enabled."""
        if self.lazy_load:
            # Nothing to do here; tensors are presumably fetched later via load_from_disk().
            return
        self.load_func(weight_dict)
        if self.weight_need_transpose:
            self.weight = self.weight.t()

    def clear(self):
        """Drop references to weight, weight_scale and bias, leaving them as None."""
        for attr in ("weight", "weight_scale", "bias"):
            if hasattr(self, attr):
                delattr(self, attr)
                setattr(self, attr, None)

    def _calculate_size(self):
        """Return the total byte size of weight, weight_scale and (if present) bias."""
        tensors = [self.weight, self.weight_scale]
        bias = getattr(self, "bias", None)
        if bias is not None:
            tensors.append(bias)
        return sum(t.numel() * t.element_size() for t in tensors)

    @staticmethod
    def _pinned_like(t):
        """Allocate an uninitialized page-locked host tensor with ``t``'s shape and dtype."""
        return torch.empty(t.shape, pin_memory=True, dtype=t.dtype)

    def _alloc_pinned_weight_buffers(self):
        """Allocate pinned staging buffers for weight and weight_scale (used by to_cpu)."""
        self.pinned_weight = self._pinned_like(self.weight)
        self.pinned_weight_scale = self._pinned_like(self.weight_scale)

    def _load_bias(self, weight_dict):
        """Load the optional bias and its pinned staging buffer; set bias to None if absent."""
        if self.bias_name is not None:
            self.bias = weight_dict[self.bias_name]
            self.pinned_bias = self._pinned_like(self.bias)
        else:
            self.bias = None

    def _auto_quant_perchannel(self, weight_dict, quantizer, target_dtype):
        """Quantize the float weight with ``quantizer``; store it as ``target_dtype`` plus a float32 scale."""
        weight = weight_dict[self.weight_name].to(torch.float32)
        weight, weight_scale, _ = quantizer.real_quant_tensor(weight)
        self.weight = weight.to(target_dtype)
        self.weight_scale = weight_scale.to(torch.float32)
        self._alloc_pinned_weight_buffers()

    def load_quantized(self, weight_dict):
        """Load an already-quantized weight and its scale (upcast to float32)."""
        self.weight = weight_dict[self.weight_name]
        self.weight_scale = weight_dict[self.weight_scale_name].float()
        self._alloc_pinned_weight_buffers()

    def load_fp8_perchannel_sym(self, weight_dict):
        """Load an fp8 (e4m3) per-channel symmetric weight, quantizing on the fly if configured."""
        if self.config.get("weight_auto_quant", False):
            self._auto_quant_perchannel(weight_dict, FloatQuantizer("e4m3", True, "per_channel"), torch.float8_e4m3fn)
        else:
            self.load_quantized(weight_dict)
        self._load_bias(weight_dict)

    def load_int8_perchannel_sym(self, weight_dict):
        """Load an int8 per-channel symmetric weight, quantizing on the fly if configured."""
        if self.config.get("weight_auto_quant", False):
            self._auto_quant_perchannel(weight_dict, IntegerQuantizer(8, True, "per_channel"), torch.int8)
        else:
            self.load_quantized(weight_dict)
        self._load_bias(weight_dict)

    def load_fp8_perblock128_sym(self, weight_dict):
        """Load an fp8 weight with 128x128 block scales, quantizing on the fly if configured."""
        if self.config.get("weight_auto_quant", False):
            self.weight, self.weight_scale = self.per_block_cast_to_fp8(weight_dict[self.weight_name])
            self._alloc_pinned_weight_buffers()
        else:
            self.load_quantized(weight_dict)
        self._load_bias(weight_dict)

    def per_block_cast_to_fp8(self, x):
        """Quantize a 2-D tensor to float8_e4m3fn with one scale per 128x128 block.

        Returns ``(q, scale)``: ``q`` has ``x``'s shape, ``scale`` has shape
        ``(ceil(m / 128), ceil(n / 128))`` and holds ``amax / 448``.
        """
        assert x.dim() == 2
        m, n = x.shape
        # Pad both dims up to multiples of 128 so the tensor tiles exactly.
        # Plain integer ceil-division: deep_gemm may be None if its import failed.
        padded_m = (m + 127) // 128 * 128
        padded_n = (n + 127) // 128 * 128
        x_padded = torch.zeros((padded_m, padded_n), dtype=x.dtype, device=x.device)
        x_padded[:m, :n] = x
        x_view = x_padded.view(-1, 128, padded_n // 128, 128)
        # 448 is the largest finite float8_e4m3fn value; clamp keeps all-zero blocks from dividing by 0.
        x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
        x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))

    # =========================
    # act quant kernels
    # =========================

    def act_quant_fp8_perchannel_sym_vllm(self, x):
        """Dynamic per-token fp8 activation quantization via vLLM's ``scaled_fp8_quant``."""
        input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannel_sym_sgl(self, x):
        """Dynamic per-token fp8 activation quantization via sgl_kernel; scale shape is (m, 1)."""
        m, k = x.shape
        # Allocate on x's own device (not the implicit current "cuda" device) so multi-GPU placement works.
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device=x.device, requires_grad=False)
        input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device=x.device, requires_grad=False)
        sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale)
        return input_tensor_quant, input_tensor_scale

    def act_quant_int8_perchannel_sym_vllm(self, x):
        """Dynamic symmetric int8 activation quantization via vLLM's ``scaled_int8_quant``."""
        input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannelgroup128_sym_deepgemm(self, x):
        """Pure-torch fp8 quantization with one scale per (row, 128-column group); scale shape (m, n // 128)."""
        assert x.dim() == 2 and x.size(1) % 128 == 0
        m, n = x.shape
        x_view = x.view(m, -1, 128)
        x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
        return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)

    def act_quant_fp8_perchannelgroup128_sym_sgl(self, x):
        """Per-token, per-128-group fp8 activation quantization via sgl_kernel; scale shape (m, k // 128)."""
        m, k = x.shape
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device=x.device, requires_grad=False)
        input_tensor_scale = torch.empty((m, k // 128), dtype=torch.float32, device=x.device, requires_grad=False)
        sgl_kernel.sgl_per_token_group_quant_fp8(
            x,
            input_tensor_quant,
            input_tensor_scale,
            group_size=128,
            eps=1e-10,
            fp8_min=-448.0,
            fp8_max=448.0,
        )
        return input_tensor_quant, input_tensor_scale

    def state_dict(self, destination=None):
        """Export weight (undoing any load-time transpose), bias and weight_scale to host memory."""
        if destination is None:
            destination = {}
        weight = self.weight.cpu().detach().clone()
        if self.weight_need_transpose:
            weight = weight.t()
        destination[self.weight_name] = weight.contiguous()
        if getattr(self, "bias", None) is not None:
            destination[self.bias_name] = self.bias.cpu().detach().clone()
        if getattr(self, "weight_scale", None) is not None:
            destination[self.weight_scale_name] = self.weight_scale.cpu().detach().clone()
        return destination

278

Dongz's avatar
Dongz committed
279
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm")
class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        # Weight is stored transposed, so its second dim is the output width.
        out = torch.empty(
            (input_tensor.shape[0], self.weight.shape[1]),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
            requires_grad=False,
        )
        q_act, act_scale = self.act_quant_func(input_tensor)
        # vLLM's cutlass_scaled_mm fills `out` in place (scaled product, plus bias if given).
        torch.ops._C.cutlass_scaled_mm(out, q_act, self.weight, act_scale, self.weight_scale, self.bias)
        return out


Dongz's avatar
Dongz committed
314
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm")
class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        n_rows, n_cols = input_tensor.shape[0], self.weight.shape[1]
        result = torch.empty((n_rows, n_cols), dtype=input_tensor.dtype, device=input_tensor.device, requires_grad=False)
        quantized, scale = self.act_quant_func(input_tensor)
        # Same cutlass entry point as the fp8 variant; vLLM dispatches on the int8 operand dtype.
        torch.ops._C.cutlass_scaled_mm(
            result,
            quantized,
            self.weight,
            scale,
            self.weight_scale,
            self.bias,
        )
        return result


349
350
351
352
353
354
355
356
357
358
359
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        """Quantize the activation and run q8_kernels' fp8_linear; returns a bfloat16 tensor.

        Raises:
            ImportError: if ``q8_kernels`` could not be imported.
        """
        if Q8F is None:
            raise ImportError("q8_kernels is required for the Q8F matmul backend but could not be imported")
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        # The bias is passed as float32; layers without a bias (bias_name=None) have self.bias None,
        # which previously crashed on `.float()`.
        bias = self.bias.float() if self.bias is not None else None
        output_tensor = Q8F.linear.fp8_linear(
            input_tensor_quant,
            self.weight,
            bias,
            input_tensor_scale,
            self.weight_scale,
            out_dtype=torch.bfloat16,
        )
        return output_tensor.squeeze(0)


Dongz's avatar
Dongz committed
379
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F")
class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        """Quantize the activation and run q8_kernels' q8_linear; returns a bfloat16 tensor.

        Raises:
            ImportError: if ``q8_kernels`` could not be imported.
        """
        if Q8F is None:
            raise ImportError("q8_kernels is required for the Q8F matmul backend but could not be imported")
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        # Bias is optional: avoid calling `.float()` on None for bias-free layers.
        bias = self.bias.float() if self.bias is not None else None
        output_tensor = Q8F.linear.q8_linear(
            input_tensor_quant,
            self.weight,
            bias,
            input_tensor_scale,
            self.weight_scale,
            fuse_gelu=False,
            out_dtype=torch.bfloat16,
        )
        return output_tensor.squeeze(0)


410
411
@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemm(MMWeightQuantTemplate):
    """
    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm

    Quant MM:
        Weight: fp8 perblock 128x128 sym
        Act: fp8 perchannel-pergroup group=128 dynamic sym
        Kernel: Deepgemm

    Reference: https://github.com/deepseek-ai/DeepGEMM

    Example:
        Act(1024, 2048) x Weight(2048, 4096) = Out(1024, 4096)

        Act : torch.Size([1024, 2048]), torch.float8_e4m3fn
        Act Scale: torch.Size([1024, 16]), torch.float32
        Weight : torch.Size([4096, 2048]), torch.float8_e4m3fn
        Weight Scale: torch.Size([32, 16]), torch.float32
        Out : torch.Size([1024, 4096]), torch.bfloat16
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perblock128_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_deepgemm

    def apply(self, input_tensor):
        """Quantize the activation, run DeepGEMM's fp8 NT GEMM, then add the bias if present.

        Raises:
            ImportError: if ``deep_gemm`` could not be imported.
        """
        if deep_gemm is None:
            raise ImportError("deep_gemm is required for the Deepgemm matmul backend but could not be imported")
        # Weight is kept untransposed (N, K), so the output width is weight.shape[0].
        # gemm_fp8_fp8_bf16_nt writes bfloat16, so the buffer is bf16 regardless of input dtype.
        output_tensor = torch.empty(
            (input_tensor.shape[0], self.weight.shape[0]),
            dtype=torch.bfloat16,
            device=input_tensor.device,
            requires_grad=False,
        )
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        deep_gemm.gemm_fp8_fp8_bf16_nt(
            (input_tensor_quant, input_tensor_scale),
            (self.weight, self.weight_scale),
            output_tensor,
        )
        if getattr(self, "bias", None) is not None:
            output_tensor.add_(self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl

    Quant MM:
        Weight: fp8 perblock 128x128 sym
        Act: fp8 pertoken-pergroup group=128 dynamic sym
        Kernel: quant-mm using Deepgemm, act dynamic quant using Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perblock128_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_sgl

    def apply(self, input_tensor):
        """Quantize the activation with sgl_kernel, run DeepGEMM's fp8 NT GEMM, then add the bias.

        Raises:
            ImportError: if ``deep_gemm`` could not be imported.
        """
        if deep_gemm is None:
            raise ImportError("deep_gemm is required for the Deepgemm matmul backend but could not be imported")
        # gemm_fp8_fp8_bf16_nt writes bfloat16 output; weight is (N, K) so width is weight.shape[0].
        output_tensor = torch.empty(
            (input_tensor.shape[0], self.weight.shape[0]),
            dtype=torch.bfloat16,
            device=input_tensor.device,
            requires_grad=False,
        )
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        deep_gemm.gemm_fp8_fp8_bf16_nt(
            (input_tensor_quant, input_tensor_scale),
            (self.weight, self.weight_scale),
            output_tensor,
        )
        if getattr(self, "bias", None) is not None:
            output_tensor.add_(self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl")
class MMWeightWfp8channelAfp8channeldynamicVllmActSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: quant-mm using vllm, act dynamic quant using Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl

    def apply(self, input_tensor):
        out_shape = (input_tensor.shape[0], self.weight.shape[1])
        out = torch.empty(out_shape, dtype=input_tensor.dtype, device=input_tensor.device, requires_grad=False)
        act_q, act_s = self.act_quant_func(input_tensor)
        # Activation quantized by sgl_kernel; GEMM done by vLLM's cutlass kernel, writing into `out`.
        torch.ops._C.cutlass_scaled_mm(out, act_q, self.weight, act_s, self.weight_scale, self.bias)
        return out


helloyongyang's avatar
helloyongyang committed
524
525
526
527
528
529
530
531
532
533
534
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm")
class MMWeightWfp8channelAfp8channeldynamicSglActVllm(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        # sgl_kernel allocates and returns the bfloat16 output itself.
        act_q, act_scale = self.act_quant_func(input_tensor)
        return sgl_kernel.fp8_scaled_mm(
            act_q,
            self.weight,
            act_scale,
            self.weight_scale,
            torch.bfloat16,
            bias=self.bias,
        )


554
555
556
557
558
559
560
561
562
563
564
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl")
class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl

    def apply(self, input_tensor):
        # Both the activation quantization and the GEMM run in sgl_kernel; output is bfloat16.
        quantized_act, act_scale = self.act_quant_func(input_tensor)
        result = sgl_kernel.fp8_scaled_mm(
            quantized_act,
            self.weight,
            act_scale,
            self.weight_scale,
            torch.bfloat16,
            bias=self.bias,
        )
        return result


@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm")
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        """Quantize the activation with vLLM and run sgl_kernel's int8 GEMM (bfloat16 output).

        sgl_kernel.int8_scaled_mm returns its own output tensor, so no buffer is
        preallocated here (the previous version allocated one and discarded it).
        """
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = sgl_kernel.int8_scaled_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            torch.bfloat16,
            self.bias,
        )
        return output_tensor
617
618


Dongz's avatar
Dongz committed
619
if __name__ == "__main__":
    # Smoke test: requires a CUDA device and the vllm kernels. Each case builds a
    # registered weight type, loads random weights, runs one matmul and logs the
    # output shape.

    def _prequantized_fp8_weights():
        # Already-quantized fp8 weight with an explicit per-channel scale.
        return {
            "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
            "xx.bias": torch.randn(8192).to(torch.bfloat16),
            "xx.weight_scale": torch.randn(8192, 1).to(torch.float32),
        }

    def _float_weights():
        # Float weight that the layer quantizes itself (weight_auto_quant=True).
        return {
            "xx.weight": torch.randn(8192, 4096),
            "xx.bias": torch.randn(8192).to(torch.bfloat16),
        }

    smoke_cases = (
        ("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm", False, _prequantized_fp8_weights),
        ("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm", True, _float_weights),
        ("W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", True, _float_weights),
    )

    for register_key, auto_quant, make_weights in smoke_cases:
        weight_dict = make_weights()
        mm_weight = MM_WEIGHT_REGISTER[register_key]("xx.weight", "xx.bias")
        mm_weight.set_config({"weight_auto_quant": auto_quant})
        mm_weight.load(weight_dict)
        input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
        output_tensor = mm_weight.apply(input_tensor)
        logger.info(output_tensor.shape)