# mm_weight.py
import torch
from abc import ABCMeta, abstractmethod
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
from lightx2v.utils.envs import *
from loguru import logger

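# Optional kernel backends: each import below falls back to None when the package is
# missing, so only the MM weight classes whose backend is installed can be used.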
try:
    from vllm import _custom_ops as ops
except ImportError:
    ops = None

try:
    import sgl_kernel
except ImportError:
    sgl_kernel = None

try:
    import q8_kernels.functional as Q8F
except ImportError:
    Q8F = None

try:
    import deep_gemm
except ImportError:
    deep_gemm = None

try:
    from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
except ModuleNotFoundError:
    quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None


class MMWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.config = {}

    @abstractmethod
    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        # Avoid a shared mutable default argument; fall back to an empty config dict.
        self.config = config if config is not None else {}

    def to_cuda(self, non_blocking=False):
        self.weight = self.weight.cuda(non_blocking=non_blocking)
        if hasattr(self, "weight_scale"):
            self.weight_scale = self.weight_scale.cuda(non_blocking=non_blocking)
        if hasattr(self, "bias") and self.bias is not None:
            self.bias = self.bias.cuda(non_blocking=non_blocking)

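    # Offload helpers: to_cpu() copies tensors into the pinned host buffers set up by
    # load() when they exist, so non_blocking transfers can overlap with compute;
    # to_cuda() moves them back onto the device before the next apply().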
    def to_cpu(self, non_blocking=False):
        if hasattr(self, "pinned_weight"):
            self.weight = self.pinned_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
            if hasattr(self, "weight_scale_name"):
                self.weight_scale = self.pinned_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
            if self.bias is not None:
                self.bias = self.pinned_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
        else:
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "weight_scale"):
                self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "bias") and self.bias is not None:
                self.bias = self.bias.to("cpu", non_blocking=non_blocking)


@MM_WEIGHT_REGISTER("Default")
class MMWeight(MMWeightTemplate):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

    def load(self, weight_dict):
        self.weight = weight_dict[self.weight_name].t()
        self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
        self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
        self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype) if self.bias is not None else None

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        if self.bias is None:
            return torch.mm(input_tensor, self.weight, out=output_tensor)
        return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)

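    # apply() computes the usual linear forward y = x @ W (+ bias) into a preallocated
    # output buffer; W is the checkpoint weight of shape (out_features, in_features)
    # already transposed in load(), matching y = x @ weight.T + bias of nn.Linear.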
    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        destination[self.weight_name] = self.weight.cpu().detach().clone().t().contiguous()
        if hasattr(self, "bias") and self.bias is not None:
            destination[self.bias_name] = self.bias.cpu().detach().clone()
        return destination


@MM_WEIGHT_REGISTER("Default-Force-FP32")
class MMWeightForceFP32(MMWeight):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

    def load(self, weight_dict):
        super().load(weight_dict)
        self.weight = self.weight.to(torch.float32)
        if hasattr(self, "bias") and self.bias is not None:
            self.bias = self.bias.to(torch.float32)


class MMWeightQuantTemplate(MMWeightTemplate):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale"
        self.load_func = None
        self.weight_need_transpose = True
        self.act_quant_func = None
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file

    # =========================
    # weight load functions
    # =========================

    def load_from_disk(self):
        if not torch._dynamo.is_compiling():
            self.weight = self.lazy_load_file.get_tensor(self.weight_name).pin_memory()
            self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float().pin_memory()
            if self.bias_name is not None:
                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16).pin_memory()
        else:
            self.weight = self.lazy_load_file.get_tensor(self.weight_name)
            self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float()
            if self.bias_name is not None:
                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16)

        if self.weight_need_transpose:
            self.weight = self.weight.t()

    def load(self, weight_dict):
        if not self.lazy_load:
            self.load_func(weight_dict)
            if self.weight_need_transpose:
                self.weight = self.weight.t()
                self.pinned_weight = self.pinned_weight.t()

    def clear(self):
        attrs = ["weight", "weight_scale", "bias", "pinned_weight", "pinned_weight_scale", "pinned_bias"]
        for attr in attrs:
            if hasattr(self, attr):
                delattr(self, attr)
                setattr(self, attr, None)

    def _calculate_size(self):
        if self.bias is not None:
            return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size() + self.bias.numel() * self.bias.element_size()

        return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size()

    def load_quantized(self, weight_dict):
        self.weight = weight_dict[self.weight_name]
        self.weight_scale = weight_dict[self.weight_scale_name].float()

        self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
        self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)

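    # Weight loaders: with config["weight_auto_quant"] = True the floating-point
    # checkpoint weight is quantized here at load time; otherwise the checkpoint is
    # expected to already contain the quantized <prefix>.weight together with a
    # <prefix>.weight_scale tensor (handled by load_quantized above).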
    def load_fp8_perchannel_sym(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
            self.weight = weight_dict[self.weight_name].to(torch.float32)
            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.float8_e4m3fn)
            self.weight_scale = self.weight_scale.to(torch.float32)
            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
            self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
        else:
            self.load_quantized(weight_dict)

        if self.bias_name is not None:
            self.bias = weight_dict[self.bias_name]
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
        else:
            self.bias = None

    def load_int8_perchannel_sym(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
            self.weight = weight_dict[self.weight_name].to(torch.float32)
            w_quantizer = IntegerQuantizer(8, True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.int8)
            self.weight_scale = self.weight_scale.to(torch.float32)
            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
            self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
        else:
            self.load_quantized(weight_dict)

        if self.bias_name is not None:
            self.bias = weight_dict[self.bias_name]
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
        else:
            self.bias = None

    def load_fp8_perblock128_sym(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
            self.weight = weight_dict[self.weight_name]
            self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight)
            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
            self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
        else:
            self.load_quantized(weight_dict)

        if self.bias_name is not None:
            self.bias = weight_dict[self.bias_name]
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
        else:
            self.bias = None

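    # Shape sketch for per_block_cast_to_fp8, following the DeepGEMM docstring example
    # below: a (4096, 2048) weight is zero-padded to multiples of 128, quantized per
    # 128x128 block with an amax/448 scale (448 is the fp8 e4m3 max), and returned as
    # an fp8 tensor of the original shape plus a (32, 16) float32 scale grid.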
    def per_block_cast_to_fp8(self, x):
        assert x.dim() == 2
        m, n = x.shape
        x_padded = torch.zeros(
            (deep_gemm.ceil_div(m, 128) * 128, deep_gemm.ceil_div(n, 128) * 128),
            dtype=x.dtype,
            device=x.device,
        )
        x_padded[:m, :n] = x
        x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
        x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
        x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))

    # =========================
    # act quant kernels
    # =========================
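    # Per-token activation quantization sketch: for x of shape (m, k) the per-channel
    # kernels return a quantized (m, k) tensor with an (m, 1) float32 scale, while the
    # group=128 variants return an (m, k // 128) scale (one scale per 128-wide group).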
    def act_quant_int8_perchannel_sym_torchao(self, x):
        input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannel_sym_vllm(self, x):
        input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannel_sym_sgl(self, x):
        m, k = x.shape
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
        input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device="cuda", requires_grad=False)
        sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale)
        return input_tensor_quant, input_tensor_scale

    def act_quant_int8_perchannel_sym_vllm(self, x):
        input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannelgroup128_sym_deepgemm(self, x):
        assert x.dim() == 2 and x.size(1) % 128 == 0
        m, n = x.shape
        x_view = x.view(m, -1, 128)
        x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
        return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)

    def act_quant_fp8_perchannelgroup128_sym_sgl(self, x):
        m, k = x.shape
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
        input_tensor_scale = torch.empty((m, k // 128), dtype=torch.float32, device="cuda", requires_grad=False)
        sgl_kernel.sgl_per_token_group_quant_fp8(
            x,
            input_tensor_quant,
            input_tensor_scale,
            group_size=128,
            eps=1e-10,
            fp8_min=-448.0,
            fp8_max=448.0,
        )
        return input_tensor_quant, input_tensor_scale

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        if self.weight_need_transpose:
            destination[self.weight_name] = self.weight.cpu().detach().clone().t().contiguous()
        else:
            destination[self.weight_name] = self.weight.cpu().detach().clone().contiguous()
        if hasattr(self, "bias") and self.bias is not None:
            destination[self.bias_name] = self.bias.cpu().detach().clone()
        if hasattr(self, "weight_scale"):
            destination[self.weight_name.removesuffix(".weight") + ".weight_scale"] = self.weight_scale.cpu().detach().clone()
        return destination


@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm")
class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.bias,
        )
        return output_tensor


@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm")
class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.bias,
        )
        return output_tensor


@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = Q8F.linear.fp8_linear(
            input_tensor_quant,
            self.weight,
            self.bias.float() if self.bias is not None else None,
            input_tensor_scale,
            self.weight_scale,
            out_dtype=torch.bfloat16,
        )
        return output_tensor.squeeze(0)


@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F")
class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = Q8F.linear.q8_linear(
            input_tensor_quant,
            self.weight,
            self.bias.float() if self.bias is not None else None,
            input_tensor_scale,
            self.weight_scale,
            fuse_gelu=False,
            out_dtype=torch.bfloat16,
        )
        return output_tensor.squeeze(0)


@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemm(MMWeightQuantTemplate):
    """
    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm

    Quant MM:
        Weight: fp8 perblock 128x128 sym
        Act: fp8 perchannel-pergroup group=128 dynamic sym
        Kernel: Deepgemm

    Reference: https://github.com/deepseek-ai/DeepGEMM

    Example:
        Act(1024, 2048) x Weight(2048, 4096) = Out(1024, 4096)

        Act : torch.Size([1024, 2048]), torch.float8_e4m3fn
        Act Scale: torch.Size([1024, 16]), torch.float32
        Weight : torch.Size([4096, 2048]), torch.float8_e4m3fn
        Weight Scale: torch.Size([32, 16]), torch.float32
        Out : torch.Size([1024, 4096]), torch.bfloat16
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perblock128_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_deepgemm

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[0])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        deep_gemm.gemm_fp8_fp8_bf16_nt(
            (input_tensor_quant, input_tensor_scale),
            (self.weight, self.weight_scale),
            output_tensor,
        )
        if hasattr(self, "bias") and self.bias is not None:
            output_tensor.add_(self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl

    Quant MM:
        Weight: fp8 perblock 128x128 sym
        Act: fp8 pertoken-pergroup group=128 dynamic sym
        Kernel: quant-mm using Deepgemm, act dynamic quant using Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perblock128_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_sgl

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[0])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        deep_gemm.gemm_fp8_fp8_bf16_nt(
            (input_tensor_quant, input_tensor_scale),
            (self.weight, self.weight_scale),
            output_tensor,
        )
        if hasattr(self, "bias") and self.bias is not None:
            output_tensor.add_(self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl")
class MMWeightWfp8channelAfp8channeldynamicVllmActSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: quant-mm using vllm, act dynamic quant using Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.bias,
        )
        return output_tensor


@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm")
class MMWeightWfp8channelAfp8channeldynamicSglActVllm(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = sgl_kernel.fp8_scaled_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            torch.bfloat16,
            bias=self.bias,
        )
        return output_tensor


@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl")
class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = sgl_kernel.fp8_scaled_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            torch.bfloat16,
            bias=self.bias,
        )
        return output_tensor


@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm")
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = sgl_kernel.int8_scaled_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            torch.bfloat16,
            self.bias,
        )
        return output_tensor


@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao")
class MMWeightWint8channelAint8channeldynamicTorchao(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Torchao
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=torch.bfloat16)
        if self.bias is not None:
            output_tensor = output_tensor + self.bias

        return output_tensor


if __name__ == "__main__":
    weight_dict = {
        "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
        "xx.bias": torch.randn(8192).to(torch.bfloat16),
        "xx.weight_scale": torch.randn(8192, 1).to(torch.float32),
    }

    mm_weight = MM_WEIGHT_REGISTER["W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
    mm_weight.set_config({"weight_auto_quant": False})
    mm_weight.load(weight_dict)
    mm_weight.to_cuda()
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    logger.info(output_tensor.shape)

    weight_dict = {
        "xx.weight": torch.randn(8192, 4096),
        "xx.bias": torch.randn(8192).to(torch.bfloat16),
    }

    mm_weight = MM_WEIGHT_REGISTER["W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
    mm_weight.set_config({"weight_auto_quant": True})
    mm_weight.load(weight_dict)
    mm_weight.to_cuda()
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    logger.info(output_tensor.shape)

    weight_dict = {
        "xx.weight": torch.randn(8192, 4096),
        "xx.bias": torch.randn(8192).to(torch.bfloat16),
    }

    mm_weight = MM_WEIGHT_REGISTER["W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
    mm_weight.set_config({"weight_auto_quant": True})
    mm_weight.load(weight_dict)
    mm_weight.to_cuda()
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    logger.info(output_tensor.shape)
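
    # A minimal sketch of the unquantized "Default" path (assumes a CUDA device; the
    # "yy.*" names are placeholders). The weight is stored as (out_features, in_features)
    # like nn.Linear, transposed inside load(), and moved with to_cuda()/to_cpu().
    weight_dict = {
        "yy.weight": torch.randn(8192, 4096).to(torch.bfloat16),
        "yy.bias": torch.randn(8192).to(torch.bfloat16),
    }
    mm_weight = MM_WEIGHT_REGISTER["Default"]("yy.weight", "yy.bias")
    mm_weight.set_config()
    mm_weight.load(weight_dict)
    mm_weight.to_cuda()
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    logger.info(output_tensor.shape)
    mm_weight.to_cpu()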